Commit e6005d5

[mlir][sparse] support 2:4 structured sparsity and loose compressed (#69968)
This adds library support for these two new level formats.
1 parent 83a6b02 · commit e6005d5
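For orientation, loose compressed differs from plain compressed only in its positions buffer: each segment gets its own explicit (lo, hi) pair rather than sharing one boundary with its neighbor, which permits gaps or padding between segments. The 2:4 format (block2_4) stores no positions at all, since every block of four elements carries exactly two nonzeros; only the within-block coordinates are kept. A minimal standalone C++ sketch of the pair layout, with buffer contents copied from the FileCheck expectations of the new integration test below (all names are local to the sketch, not library API):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // CSR positions for the test's 3x8 input matrix: row r owns the
  // coordinate/value range [pos[r], pos[r+1]).
  std::vector<uint64_t> csrPos = {0, 4, 8, 12};

  // Loose compressed positions: one (lo, hi) pair per row. A tightly
  // packed tensor simply repeats the shared boundaries, giving the
  // pairs (0,4) (4,8) (8,12) the test checks below.
  std::vector<uint64_t> loosePos = {0, 4, 4, 8, 8, 12};

  for (uint64_t r = 0; r < 3; r++) {
    assert(loosePos[2 * r] == csrPos[r]);         // lo of row r
    assert(loosePos[2 * r + 1] == csrPos[r + 1]); // hi of row r
  }
  return 0;
}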

File tree: 5 files changed, +208 -56 lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 7 additions & 1 deletion
@@ -277,7 +277,7 @@ constexpr bool isCompressedDLT(DimLevelType dlt) {
          static_cast<uint8_t>(DimLevelType::Compressed);
 }
 
-/// Check if the `DimLevelType` is compressed (regardless of properties).
+/// Check if the `DimLevelType` is loose compressed (regardless of properties).
 constexpr bool isLooseCompressedDLT(DimLevelType dlt) {
   return (static_cast<uint8_t>(dlt) & ~3) ==
          static_cast<uint8_t>(DimLevelType::LooseCompressed);
@@ -289,6 +289,12 @@ constexpr bool isSingletonDLT(DimLevelType dlt) {
          static_cast<uint8_t>(DimLevelType::Singleton);
 }
 
+/// Check if the `DimLevelType` is 2OutOf4 (regardless of properties).
+constexpr bool is2OutOf4DLT(DimLevelType dlt) {
+  return (static_cast<uint8_t>(dlt) & ~3) ==
+         static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
+}
+
 /// Check if the `DimLevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedDLT(DimLevelType dlt) {
   return !(static_cast<uint8_t>(dlt) & 2);
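All of these predicates share one pattern: the two low bits of a `DimLevelType` carry the per-level properties (the `& 2` test in `isOrderedDLT` reads one of them; the other tracks uniqueness), so masking with `~3` reduces any property variant to its base format before comparing. A self-contained mock of the pattern; the 0x40 base value here is an assumption for illustration, the real encodings live in Enums.h:

#include <cstdint>

// Mock of the predicate pattern; 0x40 is an assumed, illustrative base value
// for TwoOutOfFour, not necessarily the encoding defined in Enums.h.
enum class MockDLT : uint8_t { TwoOutOfFour = 0x40 };

constexpr bool mockIs2OutOf4(MockDLT dlt) {
  return (static_cast<uint8_t>(dlt) & ~3) ==
         static_cast<uint8_t>(MockDLT::TwoOutOfFour);
}

// The base value passes, and its property variants (base | 1, base | 2)
// still pass, because `& ~3` strips both property bits first.
static_assert(mockIs2OutOf4(MockDLT::TwoOutOfFour), "base format");
static_assert(mockIs2OutOf4(static_cast<MockDLT>(0x40 | 1)), "property bit 0");
static_assert(mockIs2OutOf4(static_cast<MockDLT>(0x40 | 2)), "property bit 1");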

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 65 additions & 50 deletions
@@ -115,11 +115,19 @@ class SparseTensorStorageBase {
     return isCompressedDLT(getLvlType(l));
   }
 
+  /// Safely checks if the level uses loose compressed storage.
+  bool isLooseCompressedLvl(uint64_t l) const {
+    return isLooseCompressedDLT(getLvlType(l));
+  }
+
   /// Safely checks if the level uses singleton storage.
   bool isSingletonLvl(uint64_t l) const {
     return isSingletonDLT(getLvlType(l));
   }
 
+  /// Safely checks if the level uses 2 out of 4 storage.
+  bool is2OutOf4Lvl(uint64_t l) const { return is2OutOf4DLT(getLvlType(l)); }
+
   /// Safely checks if the level is ordered.
   bool isOrderedLvl(uint64_t l) const { return isOrderedDLT(getLvlType(l)); }
 
@@ -138,9 +146,6 @@ class SparseTensorStorageBase {
   MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATES)
 #undef DECL_GETCOORDINATES
 
-  /// Gets the coordinate-value stored at the given level and position.
-  virtual uint64_t getCrd(uint64_t lvl, uint64_t pos) const = 0;
-
   /// Gets primary storage.
 #define DECL_GETVALUES(VNAME, V) virtual void getValues(std::vector<V> **);
   MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
@@ -280,13 +285,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     *out = &values;
   }
 
-  /// Returns coordinate at given position.
-  uint64_t getCrd(uint64_t lvl, uint64_t pos) const final {
-    assert(isCompressedDLT(getLvlType(lvl)) || isSingletonDLT(getLvlType(lvl)));
-    assert(pos < coordinates[lvl].size());
-    return coordinates[lvl][pos]; // Converts the stored `C` into `uint64_t`.
-  }
-
   /// Partially specialize forwarding insertions based on template types.
   void forwardingInsert(const uint64_t *dimCoords, V val) final {
     assert(dimCoords && coo);
@@ -302,7 +300,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     if (allDense) {
       uint64_t lvlRank = getLvlRank();
       uint64_t valIdx = 0;
-      // Linearize the address
+      // Linearize the address.
       for (uint64_t lvl = 0; lvl < lvlRank; lvl++)
         valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
       values[valIdx] = val;
@@ -441,16 +439,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   }
 
 private:
-  /// Appends an arbitrary new position to `positions[lvl]`. This method
-  /// checks that `pos` is representable in the `P` type; however, it
-  /// does not check that `pos` is semantically valid (i.e., larger than
-  /// the previous position and smaller than `coordinates[lvl].capacity()`).
-  void appendPos(uint64_t lvl, uint64_t pos, uint64_t count = 1) {
-    assert(isCompressedLvl(lvl));
-    positions[lvl].insert(positions[lvl].end(), count,
-                          detail::checkOverflowCast<P>(pos));
-  }
-
   /// Appends coordinate `crd` to level `lvl`, in the semantically
   /// general sense. For non-dense levels, that means appending to the
   /// `coordinates[lvl]` array, checking that `crd` is representable in
@@ -461,11 +449,11 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// `full` is the number of "entries" already written to `values` for this
   /// segment (aka one after the highest coordinate previously appended).
   void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) {
-    const auto dlt = getLvlType(lvl); // Avoid redundant bounds checking.
-    if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) {
+    if (!isDenseLvl(lvl)) {
+      assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) ||
+             isSingletonLvl(lvl) || is2OutOf4Lvl(lvl));
       coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
     } else { // Dense level.
-      assert(isDenseDLT(dlt));
       assert(crd >= full && "Coordinate was already filled");
       if (crd == full)
         return; // Short-circuit, since it'll be a nop.
@@ -482,15 +470,13 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// storage, as opposed to "level-sizes" which are the cardinality
   /// of possible coordinates for that level.
   uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
-    const auto dlt = getLvlType(l); // Avoid redundant bounds checking.
-    if (isCompressedDLT(dlt))
+    if (isCompressedLvl(l))
       return positions[l][parentSz];
-    if (isSingletonDLT(dlt))
+    if (isSingletonLvl(l))
       return parentSz; // New size is same as the parent.
-    if (isDenseDLT(dlt))
-      return parentSz * getLvlSize(l);
-    MLIR_SPARSETENSOR_FATAL("unsupported level type: %d\n",
-                            static_cast<uint8_t>(dlt));
+    // TODO: support levels assignment for loose/2:4?
+    assert(isDenseLvl(l));
+    return parentSz * getLvlSize(l);
   }
 
   /// Initializes sparse tensor storage scheme from a memory-resident sparse
@@ -514,7 +500,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     uint64_t seg = lo + 1;
     if (isUniqueLvl(l))
       while (seg < hi && lvlElements[seg].coords[l] == c)
-        ++seg;
+        seg++;
     // Handle segment in interval for sparse or dense level.
     appendCrd(l, full, c);
     full = c + 1;
@@ -529,14 +515,22 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Finalizes the sparse position structure at this level.
   void finalizeSegment(uint64_t l, uint64_t full = 0, uint64_t count = 1) {
     if (count == 0)
-      return; // Short-circuit, since it'll be a nop.
-    const auto dlt = getLvlType(l); // Avoid redundant bounds checking.
-    if (isCompressedDLT(dlt)) {
-      appendPos(l, coordinates[l].size(), count);
-    } else if (isSingletonDLT(dlt)) {
+      return; // Short-circuit, since it'll be a nop.
+    if (isCompressedLvl(l)) {
+      uint64_t pos = coordinates[l].size();
+      positions[l].insert(positions[l].end(), count,
+                          detail::checkOverflowCast<P>(pos));
+    } else if (isLooseCompressedLvl(l)) {
+      // Finish this level, and push pairs for the empty ones, and one
+      // more for next level. Note that this always leaves one extra
+      // unused element at the end.
+      uint64_t pos = coordinates[l].size();
+      positions[l].insert(positions[l].end(), 2 * count,
+                          detail::checkOverflowCast<P>(pos));
+    } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
      return; // Nothing to finalize.
    } else { // Dense dimension.
-      assert(isDenseDLT(dlt));
+      assert(isDenseLvl(l));
       const uint64_t sz = getLvlSizes()[l];
       assert(sz >= full && "Segment is overfull");
       count = detail::checkedMul(count, sz - full);
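Traced on the CSR_hi tensor from the new test (three rows of four nonzeros each), the net effect of the loose compressed branch is easy to see. A sketch with local names; the real method also handles overflow casting and the other level types:

#include <cstdint>
#include <iostream>
#include <vector>

// Mimics only the loose compressed branch above: finishing `count` segments
// appends 2 * count copies of the current coordinate count, i.e. the `hi` of
// the finished segment followed by the `lo` of the next one.
void finalizeLoose(std::vector<uint64_t> &positions, uint64_t nCrds,
                   uint64_t count = 1) {
  positions.insert(positions.end(), 2 * count, nCrds);
}

int main() {
  std::vector<uint64_t> pos = {0}; // the constructor seeds the first `lo`
  finalizeLoose(pos, 4);  // row 0 done: 4 coordinates written so far
  finalizeLoose(pos, 8);  // row 1 done
  finalizeLoose(pos, 12); // row 2 done
  // pos == {0, 4, 4, 8, 8, 12, 12}: the pairs (0,4) (4,8) (8,12) plus the
  // one extra unused trailing element that the code comment above mentions.
  for (uint64_t p : pos)
    std::cout << p << ' ';
  std::cout << '\n';
  return 0;
}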
@@ -589,7 +583,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
         (crd < cur && !isOrderedLvl(l))) {
       return l;
     }
-
     if (crd < cur) {
       assert(false && "non-lexicographic insertion");
       return -1u;
@@ -609,27 +602,37 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       return;
     }
     if (isCompressedLvl(l)) {
-      // Look up the bounds of the `l`-level segment determined by the
-      // `(l - 1)`-level position `parentPos`.
       const std::vector<P> &positionsL = positions[l];
       assert(parentPos + 1 < positionsL.size());
       const uint64_t pstart = static_cast<uint64_t>(positionsL[parentPos]);
       const uint64_t pstop = static_cast<uint64_t>(positionsL[parentPos + 1]);
-      // Loop-invariant code for looking up the `l`-level coordinates.
       const std::vector<C> &coordinatesL = coordinates[l];
       assert(pstop <= coordinatesL.size());
-      for (uint64_t pos = pstart; pos < pstop; ++pos) {
+      for (uint64_t pos = pstart; pos < pstop; pos++) {
         lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
         toCOO(pos, l + 1, dimCoords);
       }
-    } else if (isSingletonLvl(l)) {
-      lvlCursor[l] = getCrd(l, parentPos);
+    } else if (isLooseCompressedLvl(l)) {
+      const std::vector<P> &positionsL = positions[l];
+      assert(2 * parentPos + 1 < positionsL.size());
+      const uint64_t pstart = static_cast<uint64_t>(positionsL[2 * parentPos]);
+      const uint64_t pstop =
+          static_cast<uint64_t>(positionsL[2 * parentPos + 1]);
+      const std::vector<C> &coordinatesL = coordinates[l];
+      assert(pstop <= coordinatesL.size());
+      for (uint64_t pos = pstart; pos < pstop; pos++) {
+        lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
+        toCOO(pos, l + 1, dimCoords);
+      }
+    } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+      assert(parentPos < coordinates[l].size());
+      lvlCursor[l] = static_cast<uint64_t>(coordinates[l][parentPos]);
       toCOO(parentPos, l + 1, dimCoords);
     } else { // Dense level.
       assert(isDenseLvl(l));
       const uint64_t sz = getLvlSizes()[l];
       const uint64_t pstart = parentPos * sz;
-      for (uint64_t c = 0; c < sz; ++c) {
+      for (uint64_t c = 0; c < sz; c++) {
         lvlCursor[l] = c;
         toCOO(pstart + c, l + 1, dimCoords);
       }
@@ -706,19 +709,30 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
   bool allDense = true;
   uint64_t sz = 1;
   for (uint64_t l = 0; l < lvlRank; l++) {
-    const DimLevelType dlt = lvlTypes[l]; // Avoid redundant bounds checking.
-    if (isCompressedDLT(dlt)) {
+    if (isCompressedLvl(l)) {
       positions[l].reserve(sz + 1);
       positions[l].push_back(0);
       coordinates[l].reserve(sz);
       sz = 1;
       allDense = false;
-    } else if (isSingletonDLT(dlt)) {
+    } else if (isLooseCompressedLvl(l)) {
+      positions[l].reserve(2 * sz + 1); // last one unused
+      positions[l].push_back(0);
       coordinates[l].reserve(sz);
       sz = 1;
       allDense = false;
+    } else if (isSingletonLvl(l)) {
+      coordinates[l].reserve(sz);
+      sz = 1;
+      allDense = false;
+    } else if (is2OutOf4Lvl(l)) {
+      assert(allDense && l == lvlRank - 1 && "unexpected 2:4 usage");
+      sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
+      coordinates[l].reserve(sz);
+      values.reserve(sz);
+      allDense = false;
     } else { // Dense level.
-      assert(isDenseDLT(dlt));
+      assert(isDenseLvl(l));
       sz = detail::checkedMul(sz, lvlSizes[l]);
     }
   }
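The 2:4 branch sizes its buffers from the dense prefix: since every block of four elements materializes exactly two, the coordinate and value counts are half of the padded element count. A quick check of that arithmetic for the NV_24 encoding in the new test (level sizes 3, 2, and 4; names local to the sketch):

#include <cassert>
#include <cstdint>

int main() {
  // Level sizes of the test's NV_24 encoding: i : dense (3),
  // j floordiv 4 : dense (2), j mod 4 : block2_4 (4).
  const uint64_t lvlSizes[3] = {3, 2, 4};
  uint64_t sz = 1;
  sz *= lvlSizes[0];         // dense level i
  sz *= lvlSizes[1];         // dense level j floordiv 4: 6 blocks of four
  sz = sz * lvlSizes[2] / 2; // 2:4 level: half of each block is stored
  assert(sz == 12); // matches the 12 coordinates/values the test checks
  return 0;
}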
@@ -773,6 +787,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       positions[l].assign(posPtr, posPtr + parentSz + 1);
       coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
     } else {
+      // TODO: support levels assignment for loose/2:4?
       assert(isDenseLvl(l));
     }
     parentSz = assembledSize(parentSz, l);

mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp

Lines changed: 2 additions & 5 deletions
@@ -36,11 +36,8 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
   assert(lvlRank > 0 && "Trivial shape is unsupported");
   for (uint64_t l = 0; l < lvlRank; ++l) {
     assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
-    const auto dlt = lvlTypes[l];
-    if (!(isDenseDLT(dlt) || isCompressedDLT(dlt) || isSingletonDLT(dlt))) {
-      MLIR_SPARSETENSOR_FATAL("unsupported level type: %d\n",
-                              static_cast<uint8_t>(dlt));
-    }
+    assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
+           isSingletonLvl(l) || is2OutOf4Lvl(l));
   }
 }
 

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
+// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// REDEFINE: %{env} = TENSOR0="%mlir_src_dir/test/Integration/data/ds.mtx"
+// RUN: %{compile} | env %{env} %{run} | FileCheck %s
+//
+// TODO: enable!
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
+// R_UN: %{compile} | env %{env} %{run} | FileCheck %s
+
+!Filename = !llvm.ptr<i8>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> ( i : dense, j : compressed)
+}>
+
+#CSR_hi = #sparse_tensor.encoding<{
+  map = (i, j) -> ( i : dense, j : loose_compressed)
+}>
+
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i            : dense,
+                      j floordiv 4 : dense,
+                      j mod 4      : block2_4),
+  crdWidth = 8
+}>
+
+module {
+
+  func.func private @getTensorFilename(index) -> (!Filename)
+
+  //
+  // Input matrix:
+  //
+  // [[0.0, 0.0, 1.0, 2.0, 0.0, 3.0, 0.0, 4.0],
+  //  [0.0, 5.0, 6.0, 0.0, 7.0, 0.0, 0.0, 8.0],
+  //  [9.0, 0.0, 10.0, 0.0, 11.0, 12.0, 0.0, 0.0]]
+  //
+  func.func @entry() {
+    %u0 = arith.constant 0 : i8
+    %c0 = arith.constant 0 : index
+    %f0 = arith.constant 0.0 : f64
+
+    %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
+    %A1 = sparse_tensor.new %fileName : !Filename to tensor<?x?xf64, #CSR>
+    %A2 = sparse_tensor.new %fileName : !Filename to tensor<?x?xf64, #CSR_hi>
+    %A3 = sparse_tensor.new %fileName : !Filename to tensor<?x?xf64, #NV_24>
+
+    //
+    // CSR:
+    //
+    // CHECK: ( 0, 4, 8, 12 )
+    // CHECK-NEXT: ( 2, 3, 5, 7, 1, 2, 4, 7, 0, 2, 4, 5 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 )
+    //
+    %pos1 = sparse_tensor.positions %A1 {level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    %vecp1 = vector.transfer_read %pos1[%c0], %c0 : memref<?xindex>, vector<4xindex>
+    vector.print %vecp1 : vector<4xindex>
+    %crd1 = sparse_tensor.coordinates %A1 {level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    %vecc1 = vector.transfer_read %crd1[%c0], %c0 : memref<?xindex>, vector<12xindex>
+    vector.print %vecc1 : vector<12xindex>
+    %val1 = sparse_tensor.values %A1 : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %vecv1 = vector.transfer_read %val1[%c0], %f0 : memref<?xf64>, vector<12xf64>
+    vector.print %vecv1 : vector<12xf64>
+
+    //
+    // CSR_hi:
+    //
+    // CHECK-NEXT: ( 0, 4, 4, 8, 8, 12 )
+    // CHECK-NEXT: ( 2, 3, 5, 7, 1, 2, 4, 7, 0, 2, 4, 5 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 )
+    //
+    %pos2 = sparse_tensor.positions %A2 {level = 1 : index } : tensor<?x?xf64, #CSR_hi> to memref<?xindex>
+    %vecp2 = vector.transfer_read %pos2[%c0], %c0 : memref<?xindex>, vector<6xindex>
+    vector.print %vecp2 : vector<6xindex>
+    %crd2 = sparse_tensor.coordinates %A2 {level = 1 : index } : tensor<?x?xf64, #CSR_hi> to memref<?xindex>
+    %vecc2 = vector.transfer_read %crd2[%c0], %c0 : memref<?xindex>, vector<12xindex>
+    vector.print %vecc2 : vector<12xindex>
+    %val2 = sparse_tensor.values %A2 : tensor<?x?xf64, #CSR_hi> to memref<?xf64>
+    %vecv2 = vector.transfer_read %val2[%c0], %f0 : memref<?xf64>, vector<12xf64>
+    vector.print %vecv2 : vector<12xf64>
+
+    //
+    // NV_24
+    //
+    // CHECK-NEXT: ( 2, 3, 1, 3, 1, 2, 0, 3, 0, 2, 0, 1 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 )
+    //
+    %crd3 = sparse_tensor.coordinates %A3 {level = 2 : index } : tensor<?x?xf64, #NV_24> to memref<?xi8>
+    %vecc3 = vector.transfer_read %crd3[%c0], %u0 : memref<?xi8>, vector<12xi8>
+    vector.print %vecc3 : vector<12xi8>
+    %val3 = sparse_tensor.values %A3 : tensor<?x?xf64, #NV_24> to memref<?xf64>
+    %vecv3 = vector.transfer_read %val3[%c0], %f0 : memref<?xf64>, vector<12xf64>
+    vector.print %vecv3 : vector<12xf64>
+
+    // Release the resources.
+    bufferization.dealloc_tensor %A1: tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %A2: tensor<?x?xf64, #CSR_hi>
+    bufferization.dealloc_tensor %A3: tensor<?x?xf64, #NV_24>
+
+    return
+  }
+}
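The NV_24 expectations can be derived by hand: each 8-column row splits into two blocks of four holding exactly two nonzeros each, and the stored coordinates (narrowed to i8 by crdWidth = 8) are the within-block column offsets. A sketch with local names:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Column indices of the nonzeros in the input matrix, row by row.
  std::vector<uint64_t> cols = {2, 3, 5, 7, 1, 2, 4, 7, 0, 2, 4, 5};
  // Within-block offsets: prints 2 3 1 3 1 2 0 3 0 2 0 1, matching the
  // first NV_24 CHECK line above.
  for (uint64_t c : cols)
    std::cout << (c % 4) << ' ';
  std::cout << '\n';
  return 0;
}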
