Skip to content

Commit 2045cca

Browse files
authored
[mlir][sparse] add a forwarding insertion to SparseTensorStorage (#68939)
1 parent c40902c commit 2045cca

File tree

9 files changed

+174
-129
lines changed

9 files changed

+174
-129
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,10 @@ constexpr bool isComplexPrimaryType(PrimaryType valTy) {
143143
/// The actions performed by @newSparseTensor.
144144
enum class Action : uint32_t {
145145
kEmpty = 0,
146-
// newSparseTensor no longer handles `kFromFile=1`, so we leave this
147-
// number reserved to help catch any code that still needs updating.
146+
kEmptyForward = 1,
148147
kFromCOO = 2,
149148
kSparseToSparse = 3,
150-
kEmptyCOO = 4,
149+
kFuture = 4, // not used
151150
kToCOO = 5,
152151
kToIterator = 6,
153152
kPack = 7,

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 114 additions & 75 deletions
Large diffs are not rendered by default.

mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,20 @@ extern "C" {
3737
//
3838
//===----------------------------------------------------------------------===//
3939

40-
/// The @newSparseTensor function for constructing a new sparse tensor.
4140
/// This is the "swiss army knife" method for materializing sparse
4241
/// tensors into the computation. The types of the `ptr` argument and
4342
/// the result depend on the action, as explained in the following table
4443
/// (where "STS" means a sparse-tensor-storage object, "COO" means
4544
/// a coordinate-scheme object, and "Iterator" means an iterator object).
4645
///
4746
/// Action: `ptr`: Returns:
48-
/// kEmpty unused STS, empty
49-
/// kEmptyCOO unused COO, empty
50-
/// kFromFile char* filename STS, read from the file
47+
/// kEmpty - STS, empty
48+
/// kEmptyForward - STS, empty, with forwarding COO
5149
/// kFromCOO COO STS, copied from the COO source
52-
/// kToCOO STS COO, copied from the STS source
5350
/// kSparseToSparse STS STS, copied from the STS source
54-
/// kToIterator STS Iterator, call @getNext to use and
55-
/// @delSparseTensorIterator to free.
51+
/// kToCOO STS COO, copied from the STS source
52+
/// kToIterator STS Iterator (@getNext/@delSparseTensorIterator)
53+
/// kPack buffers STS, from level buffers
5654
MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensor( // NOLINT
5755
StridedMemRefType<index_type, 1> *dimSizesRef,
5856
StridedMemRefType<index_type, 1> *lvlSizesRef,
@@ -84,19 +82,15 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSEPOSITIONS)
8482
MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
8583
#undef DECL_SPARSECOORDINATES
8684

87-
/// Coordinate-scheme method for adding a new element.
88-
/// TODO: remove dim2lvl
89-
#define DECL_ADDELT(VNAME, V) \
90-
MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_addElt##VNAME( \
91-
void *lvlCOO, StridedMemRefType<V, 0> *vref, \
92-
StridedMemRefType<index_type, 1> *dimCoordsRef, \
93-
StridedMemRefType<index_type, 1> *dim2lvlRef);
94-
MLIR_SPARSETENSOR_FOREVERY_V(DECL_ADDELT)
95-
#undef DECL_ADDELT
85+
/// Tensor-storage method for a dim to lvl forwarding insertion.
86+
#define DECL_FORWARDINGINSERT(VNAME, V) \
87+
MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_forwardingInsert##VNAME( \
88+
void *tensor, StridedMemRefType<V, 0> *vref, \
89+
StridedMemRefType<index_type, 1> *dimCoordsRef); \
90+
MLIR_SPARSETENSOR_FOREVERY_V(DECL_FORWARDINGINSERT)
91+
#undef DECL_FORWARDINGINSERT
9692

9793
/// Coordinate-scheme method for getting the next element while iterating.
98-
/// The `cref` argument uses the same coordinate-space as the `iter` (which
99-
/// can be either dim- or lvl-coords, depending on context).
10094
#define DECL_GETNEXT(VNAME, V) \
10195
MLIR_CRUNNERUTILS_EXPORT bool _mlir_ciface_getNext##VNAME( \
10296
void *iter, StridedMemRefType<index_type, 1> *cref, \
@@ -185,8 +179,11 @@ MLIR_CRUNNERUTILS_EXPORT index_type sparseLvlSize(void *tensor, index_type l);
185179
/// Tensor-storage method to get the size of the given dimension.
186180
MLIR_CRUNNERUTILS_EXPORT index_type sparseDimSize(void *tensor, index_type d);
187181

182+
/// Tensor-storage method to finalize forwarding insertions.
183+
MLIR_CRUNNERUTILS_EXPORT void endForwardingInsert(void *tensor);
184+
188185
/// Tensor-storage method to finalize lexicographic insertions.
189-
MLIR_CRUNNERUTILS_EXPORT void endInsert(void *tensor);
186+
MLIR_CRUNNERUTILS_EXPORT void endLexInsert(void *tensor);
190187

191188
/// Coordinate-scheme method to write to file in extended FROSTT format.
192189
#define DECL_OUTSPARSETENSOR(VNAME, V) \

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
596596
ConversionPatternRewriter &rewriter) const override {
597597
if (op.getHasInserts()) {
598598
// Finalize any pending insertions.
599-
StringRef name = "endInsert";
599+
StringRef name = "endLexInsert";
600600
createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
601601
EmitCInterface::Off);
602602
}

mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
8080
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES)
8181
#undef IMPL_GETVALUES
8282

83+
#define IMPL_FORWARDINGINSERT(VNAME, V) \
84+
void SparseTensorStorageBase::forwardingInsert(const uint64_t *, V) { \
85+
FATAL_PIV("forwardingInsert" #VNAME); \
86+
}
87+
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT)
88+
#undef IMPL_FORWARDINGINSERT
89+
8390
#define IMPL_LEXINSERT(VNAME, V) \
8491
void SparseTensorStorageBase::lexInsert(const uint64_t *, V) { \
8592
FATAL_PIV("lexInsert" #VNAME); \

mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,16 @@ extern "C" {
177177
#define CASE(p, c, v, P, C, V) \
178178
if (posTp == (p) && crdTp == (c) && valTp == (v)) { \
179179
switch (action) { \
180-
case Action::kEmpty: \
180+
case Action::kEmpty: { \
181181
return SparseTensorStorage<P, C, V>::newEmpty( \
182-
dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim); \
182+
dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
183+
false); \
184+
} \
185+
case Action::kEmptyForward: { \
186+
return SparseTensorStorage<P, C, V>::newEmpty( \
187+
dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
188+
true); \
189+
} \
183190
case Action::kFromCOO: { \
184191
assert(ptr && "Received nullptr for SparseTensorCOO object"); \
185192
auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr); \
@@ -193,8 +200,9 @@ extern "C" {
193200
dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
194201
dimRank, tensor); \
195202
} \
196-
case Action::kEmptyCOO: \
197-
return new SparseTensorCOO<V>(lvlRank, lvlSizes); \
203+
case Action::kFuture: { \
204+
break; \
205+
} \
198206
case Action::kToCOO: { \
199207
assert(ptr && "Received nullptr for SparseTensorStorage object"); \
200208
auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
@@ -405,29 +413,20 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
405413
#undef IMPL_SPARSECOORDINATES
406414
#undef IMPL_GETOVERHEAD
407415

408-
// TODO: use MapRef here for translation of coordinates
409-
// TODO: remove dim2lvl
410-
#define IMPL_ADDELT(VNAME, V) \
411-
void *_mlir_ciface_addElt##VNAME( \
412-
void *lvlCOO, StridedMemRefType<V, 0> *vref, \
413-
StridedMemRefType<index_type, 1> *dimCoordsRef, \
414-
StridedMemRefType<index_type, 1> *dim2lvlRef) { \
415-
assert(lvlCOO &&vref); \
416+
#define IMPL_FORWARDINGINSERT(VNAME, V) \
417+
void _mlir_ciface_forwardingInsert##VNAME( \
418+
void *t, StridedMemRefType<V, 0> *vref, \
419+
StridedMemRefType<index_type, 1> *dimCoordsRef) { \
420+
assert(t &&vref); \
416421
ASSERT_NO_STRIDE(dimCoordsRef); \
417-
ASSERT_NO_STRIDE(dim2lvlRef); \
418-
const uint64_t rank = MEMREF_GET_USIZE(dimCoordsRef); \
419-
ASSERT_USIZE_EQ(dim2lvlRef, rank); \
420422
const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef); \
421-
const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \
422-
std::vector<index_type> lvlCoords(rank); \
423-
for (uint64_t d = 0; d < rank; ++d) \
424-
lvlCoords[dim2lvl[d]] = dimCoords[d]; \
425-
V *value = MEMREF_GET_PAYLOAD(vref); \
426-
static_cast<SparseTensorCOO<V> *>(lvlCOO)->add(lvlCoords, *value); \
427-
return lvlCOO; \
423+
assert(dimCoords); \
424+
const V *value = MEMREF_GET_PAYLOAD(vref); \
425+
static_cast<SparseTensorStorageBase *>(t)->forwardingInsert(dimCoords, \
426+
*value); \
428427
}
429-
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_ADDELT)
430-
#undef IMPL_ADDELT
428+
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT)
429+
#undef IMPL_FORWARDINGINSERT
431430

432431
// NOTE: the `cref` argument uses the same coordinate-space as the `iter`
433432
// (which can be either dim- or lvl-coords, depending on context).
@@ -692,8 +691,12 @@ index_type sparseDimSize(void *tensor, index_type d) {
692691
return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
693692
}
694693

695-
void endInsert(void *tensor) {
696-
return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
694+
void endForwardingInsert(void *tensor) {
695+
return static_cast<SparseTensorStorageBase *>(tensor)->endForwardingInsert();
696+
}
697+
698+
void endLexInsert(void *tensor) {
699+
return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert();
697700
}
698701

699702
#define IMPL_OUTSPARSETENSOR(VNAME, V) \

mlir/test/Dialect/SparseTensor/conversion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ func.func @sparse_reconstruct(%arg0: tensor<128xf32, #SparseVector>) -> tensor<1
296296

297297
// CHECK-LABEL: func @sparse_reconstruct_ins(
298298
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
299-
// CHECK: call @endInsert(%[[A]]) : (!llvm.ptr<i8>) -> ()
299+
// CHECK: call @endLexInsert(%[[A]]) : (!llvm.ptr<i8>) -> ()
300300
// CHECK: return %[[A]] : !llvm.ptr<i8>
301301
func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tensor<128xf32, #SparseVector> {
302302
%0 = sparse_tensor.load %arg0 hasInserts : tensor<128xf32, #SparseVector>

mlir/test/Dialect/SparseTensor/sparse_expand.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
// CHECK-CONVERT: memref.dealloc %[[A]] : memref<?xf64>
6363
// CHECK-CONVERT: memref.dealloc %[[B]] : memref<?xi1>
6464
// CHECK-CONVERT: memref.dealloc %[[C]] : memref<?xindex>
65-
// CHECK-CONVERT: call @endInsert
65+
// CHECK-CONVERT: call @endLexInsert
6666
//
6767
func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
6868
%c0 = arith.constant 0 : index
@@ -115,7 +115,7 @@ func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
115115
// CHECK-CONVERT: memref.dealloc %[[A]] : memref<?xf64>
116116
// CHECK-CONVERT: memref.dealloc %[[B]] : memref<?xi1>
117117
// CHECK-CONVERT: memref.dealloc %[[C]] : memref<?xindex>
118-
// CHECK-CONVERT: call @endInsert
118+
// CHECK-CONVERT: call @endLexInsert
119119
//
120120
func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
121121
%B: tensor<2x4xf64, #CSR>) -> tensor<8x4xf64, #CSR> {
@@ -163,7 +163,7 @@ func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
163163
// CHECK-CONVERT: memref.dealloc %[[A]] : memref<?xf64>
164164
// CHECK-CONVERT: memref.dealloc %[[B]] : memref<?xi1>
165165
// CHECK-CONVERT: memref.dealloc %[[C]] : memref<?xindex>
166-
// CHECK-CONVERT: call @endInsert
166+
// CHECK-CONVERT: call @endLexInsert
167167
//
168168
func.func @matmul2(%A: tensor<8x2xf64, #CSC>,
169169
%B: tensor<2x4xf64, #CSC>) -> tensor<8x4xf64, #CSC> {

mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112
// CHECK: memref.dealloc %[[VAL_20]] : memref<300xf64>
113113
// CHECK: memref.dealloc %[[VAL_22]] : memref<300xi1>
114114
// CHECK: memref.dealloc %[[VAL_24]] : memref<300xindex>
115-
// CHECK: call @endInsert(%[[VAL_19]]) : (!llvm.ptr<i8>) -> ()
115+
// CHECK: call @endLexInsert(%[[VAL_19]]) : (!llvm.ptr<i8>) -> ()
116116
// CHECK: return %[[VAL_19]] : !llvm.ptr<i8>
117117
// CHECK: }
118118
func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>,

0 commit comments

Comments
 (0)