Skip to content

Commit f740366

Browse files
authored
[mlir][sparse] support type conversion from SoA COO to memrefs. (#82398)
1 parent a9b5753 commit f740366

File tree

4 files changed

+99
-9
lines changed

4 files changed

+99
-9
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,9 @@ struct LevelType {
303303
}
304304

305305
/// Check if the `LevelType` is in the `LevelFormat`.
306-
template <LevelFormat fmt>
306+
template <LevelFormat... fmt>
307307
constexpr bool isa() const {
308-
return getLvlFmt() == fmt;
308+
return (... || (getLvlFmt() == fmt)) || false;
309309
}
310310

311311
/// Check if the `LevelType` has the properties

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@
1818
namespace mlir {
1919
namespace sparse_tensor {
2020

21+
/// A simple structure that encodes a range of levels in the sparse tensors that
22+
/// forms a COO segment.
23+
struct COOSegment {
24+
std::pair<Level, Level> lvlRange; // [low, high)
25+
bool isSoA;
26+
27+
bool isSegmentStart(Level l) const { return l == lvlRange.first; }
28+
bool inSegment(Level l) const {
29+
return l >= lvlRange.first && l < lvlRange.second;
30+
}
31+
};
32+
2133
//===----------------------------------------------------------------------===//
2234
/// A wrapper around `RankedTensorType`, which has three goals:
2335
///
@@ -330,6 +342,9 @@ class SparseTensorType {
330342
/// Returns [un]ordered COO type for this sparse tensor type.
331343
RankedTensorType getCOOType(bool ordered) const;
332344

345+
/// Returns a list of COO segments in the sparse tensor types.
346+
SmallVector<COOSegment> getCOOSegments() const;
347+
333348
private:
334349
// These two must be const, to ensure coherence of the memoized fields.
335350
const RankedTensorType rtp;

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,12 @@ void StorageLayout::foreachField(
7474
callback) const {
7575
const auto lvlTypes = enc.getLvlTypes();
7676
const Level lvlRank = enc.getLvlRank();
77-
const Level cooStart = SparseTensorType(enc).getCOOStart();
78-
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
77+
SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
7978
FieldIndex fieldIdx = kDataFieldStartingIdx;
79+
80+
ArrayRef cooSegsRef = cooSegs;
8081
// Per-level storage.
81-
for (Level l = 0; l < end; l++) {
82+
for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
8283
const auto lt = lvlTypes[l];
8384
if (isWithPosLT(lt)) {
8485
if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
@@ -88,6 +89,21 @@ void StorageLayout::foreachField(
8889
if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
8990
return;
9091
}
92+
if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
93+
if (!cooSegsRef.front().isSoA) {
94+
// AoS COO, all singletons are fused into one memrefs. Skips the entire
95+
// COO segement.
96+
l = cooSegsRef.front().lvlRange.second;
97+
} else {
98+
// SoA COO, each singleton level has one memref.
99+
l++;
100+
}
101+
// Expire handled COO segment.
102+
cooSegsRef = cooSegsRef.drop_front();
103+
} else {
104+
// Non COO levels.
105+
l++;
106+
}
91107
}
92108
// The values array.
93109
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
@@ -796,13 +812,46 @@ bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
796812
}
797813

798814
Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
799-
if (hasEncoding() && lvlRank > 1)
800-
for (Level l = 0; l < lvlRank - 1; l++)
801-
if (isCOOType(l, /*isUnique=*/false))
802-
return l;
815+
SmallVector<COOSegment> coo = getCOOSegments();
816+
if (!coo.empty()) {
817+
assert(coo.size() == 1);
818+
return coo.front().lvlRange.first;
819+
}
803820
return lvlRank;
804821
}
805822

823+
SmallVector<COOSegment>
824+
mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
825+
SmallVector<COOSegment> ret;
826+
if (!hasEncoding() || lvlRank <= 1)
827+
return ret;
828+
829+
ArrayRef<LevelType> lts = getLvlTypes();
830+
Level l = 0;
831+
while (l < lvlRank) {
832+
auto lt = lts[l];
833+
if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
834+
auto cur = lts.begin() + l;
835+
auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
836+
return !lt.isa<LevelFormat::Singleton>();
837+
});
838+
unsigned cooLen = std::distance(cur, end);
839+
if (cooLen > 1) {
840+
// To support mixed SoA/AoS COO, we should break the segment when the
841+
// storage scheme changes, for now we faithfully assume that all
842+
// consecutive singleton levels have the same storage format as verified
843+
// STEA.
844+
ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
845+
lts[l + 1].isa<LevelPropNonDefault::SoA>()});
846+
}
847+
l += cooLen;
848+
} else {
849+
l++;
850+
}
851+
}
852+
return ret;
853+
}
854+
806855
RankedTensorType
807856
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
808857
SmallVector<LevelType> lvlTypes;

mlir/test/Dialect/SparseTensor/codegen.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
4949
}>
5050

51+
#SoACOO = #sparse_tensor.encoding<{
52+
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
53+
}>
54+
5155
#CooPNo = #sparse_tensor.encoding<{
5256
map = (d0, d1) -> (d1 : compressed(nonunique), d0 : singleton(nonordered))
5357
}>
@@ -67,6 +71,28 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
6771
return %arg0 : tensor<?xf64, #SparseVector>
6872
}
6973

74+
// CHECK-LABEL: func @sparse_nop_aos_coo(
75+
// CHECK-SAME: %[[POS:.*0]]: memref<?xindex>,
76+
// CHECK-SAME: %[[AoS_CRD:.*1]]: memref<?xindex>,
77+
// CHECK-SAME: %[[VAL:.*]]: memref<?xf64>,
78+
// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
79+
// CHECK: return %[[POS]], %[[AoS_CRD]], %[[VAL]], %[[A3]]
80+
func.func @sparse_nop_aos_coo(%arg0: tensor<?x?xf64, #Coo>) -> tensor<?x?xf64, #Coo> {
81+
return %arg0 : tensor<?x?xf64, #Coo>
82+
}
83+
84+
// CHECK-LABEL: func @sparse_nop_soa_coo(
85+
// CHECK-SAME: %[[POS:.*0]]: memref<?xindex>,
86+
// CHECK-SAME: %[[SoA_CRD_0:.*1]]: memref<?xindex>,
87+
// CHECK-SAME: %[[SoA_CRD_1:.*2]]: memref<?xindex>,
88+
// CHECK-SAME: %[[VAL:.*]]: memref<?xf64>,
89+
// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
90+
// CHECK: return %[[POS]], %[[SoA_CRD_0]], %[[SoA_CRD_1]], %[[VAL]], %[[A3]]
91+
func.func @sparse_nop_soa_coo(%arg0: tensor<?x?xf64, #SoACOO>) -> tensor<?x?xf64, #SoACOO> {
92+
return %arg0 : tensor<?x?xf64, #SoACOO>
93+
}
94+
95+
7096
// CHECK-LABEL: func @sparse_nop_multi_ret(
7197
// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
7298
// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,

0 commit comments

Comments
 (0)