Skip to content

Commit 5248a98

Browse files
authored
[mlir][sparse] support SoA COO in codegen path. (#82439)
*NOTE*: the `SoA` property only makes a difference on codegen path, and is ignored in libgen path at the moment (only SoA COO is supported).
1 parent 4ca0480 commit 5248a98

File tree

9 files changed

+43
-30
lines changed

9 files changed

+43
-30
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,13 @@ struct LevelType {
283283
}
284284
bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
285285

286-
LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); }
286+
LevelType stripStorageIrrelevantProperties() const {
287+
// Properties other than `SoA` do not change the storage scheme of the
288+
// sparse tensor.
289+
constexpr uint64_t mask =
290+
0xffff & ~static_cast<uint64_t>(LevelPropNonDefault::SoA);
291+
return LevelType(lvlBits & ~mask);
292+
}
287293

288294
/// Get N of NOutOfM level type.
289295
constexpr uint64_t getN() const {

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct COOSegment {
2424
std::pair<Level, Level> lvlRange; // [low, high)
2525
bool isSoA;
2626

27+
bool isAoS() const { return !isSoA; }
2728
bool isSegmentStart(Level l) const { return l == lvlRange.first; }
2829
bool inSegment(Level l) const {
2930
return l >= lvlRange.first && l < lvlRange.second;
@@ -337,7 +338,9 @@ class SparseTensorType {
337338
/// Returns the starting level of this sparse tensor type for a
338339
/// trailing COO region that spans **at least** two levels. If
339340
/// no such COO region is found, then returns the level-rank.
340-
Level getCOOStart() const;
341+
///
342+
/// DEPRECATED: use getCOOSegment instead;
343+
Level getAoSCOOStart() const;
341344

342345
/// Returns [un]ordered COO type for this sparse tensor type.
343346
RankedTensorType getCOOType(bool ordered) const;

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
182182
unsigned stride = 1;
183183
if (kind == SparseTensorFieldKind::CrdMemRef) {
184184
assert(lvl.has_value());
185-
const Level cooStart = SparseTensorType(enc).getCOOStart();
185+
const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
186186
const Level lvlRank = enc.getLvlRank();
187187
if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
188188
lvl = cooStart;
@@ -811,10 +811,10 @@ bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
811811
return !isUnique || isUniqueLvl(lvlRank - 1);
812812
}
813813

814-
Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
814+
Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
815815
SmallVector<COOSegment> coo = getCOOSegments();
816-
if (!coo.empty()) {
817-
assert(coo.size() == 1);
816+
assert(coo.size() == 1 || coo.empty());
817+
if (!coo.empty() && coo.front().isAoS()) {
818818
return coo.front().lvlRange.first;
819819
}
820820
return lvlRank;
@@ -1051,7 +1051,7 @@ static SparseTensorEncodingAttr
10511051
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
10521052
SmallVector<LevelType> lts;
10531053
for (auto lt : enc.getLvlTypes())
1054-
lts.push_back(lt.stripProperties());
1054+
lts.push_back(lt.stripStorageIrrelevantProperties());
10551055

10561056
return SparseTensorEncodingAttr::get(
10571057
enc.getContext(), lts,
@@ -1137,7 +1137,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
11371137
return op->emitError("the sparse-tensor must have an encoding attribute");
11381138

11391139
// Verifies the trailing COO.
1140-
Level cooStartLvl = stt.getCOOStart();
1140+
Level cooStartLvl = stt.getAoSCOOStart();
11411141
if (cooStartLvl < stt.getLvlRank()) {
11421142
// We only supports trailing COO for now, must be the last input.
11431143
auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
@@ -1452,7 +1452,7 @@ LogicalResult ToCoordinatesOp::verify() {
14521452

14531453
LogicalResult ToCoordinatesBufferOp::verify() {
14541454
auto stt = getSparseTensorType(getTensor());
1455-
if (stt.getCOOStart() >= stt.getLvlRank())
1455+
if (stt.getAoSCOOStart() >= stt.getLvlRank())
14561456
return emitError("expected sparse tensor with a COO region");
14571457
return success();
14581458
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
194194
valHeuristic =
195195
builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
196196
} else if (sizeHint) {
197-
if (stt.getCOOStart() == 0) {
197+
if (stt.getAoSCOOStart() == 0) {
198198
posHeuristic = constantIndex(builder, loc, 2);
199199
crdHeuristic = builder.create<arith::MulIOp>(
200200
loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
@@ -1316,7 +1316,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
13161316
Value posBack = c0; // index to the last value in the position array
13171317
Value memSize = c1; // memory size for current array
13181318

1319-
Level trailCOOStart = stt.getCOOStart();
1319+
Level trailCOOStart = stt.getAoSCOOStart();
13201320
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
13211321
// Sets up SparseTensorSpecifier.
13221322
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
@@ -1453,7 +1453,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
14531453
const auto dstTp = getSparseTensorType(op.getResult());
14541454
// Creating COO with NewOp is handled by direct IR codegen. All other cases
14551455
// are handled by rewriting.
1456-
if (!dstTp.hasEncoding() || dstTp.getCOOStart() != 0)
1456+
if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
14571457
return failure();
14581458

14591459
// Implement as follows:

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1180,7 +1180,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
11801180
PatternRewriter &rewriter) const override {
11811181
Location loc = op.getLoc();
11821182
auto stt = getSparseTensorType(op.getResult());
1183-
if (!stt.hasEncoding() || stt.getCOOStart() == 0)
1183+
if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
11841184
return failure();
11851185

11861186
// Implement the NewOp as follows:

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
568568
const auto srcTp = getSparseTensorType(tensor);
569569
const Type crdTp = srcTp.getCrdType();
570570
const Type memTp =
571-
get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getCOOStart());
571+
get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getAoSCOOStart());
572572
return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
573573
builder.getIndexAttr(lvl));
574574
}

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
103103

104104
Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
105105
OpBuilder &builder, Location loc, Level lvl) const {
106-
const Level cooStart = rType.getCOOStart();
106+
const Level cooStart = rType.getAoSCOOStart();
107107
if (lvl < cooStart)
108108
return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
109109

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class SparseTensorDescriptorImpl {
137137
}
138138

139139
Value getAOSMemRef() const {
140-
const Level cooStart = rType.getCOOStart();
140+
const Level cooStart = rType.getAoSCOOStart();
141141
assert(cooStart < rType.getLvlRank());
142142
return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
143143
}

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
3535
}>
3636

37+
#SortedCOOSoA = #sparse_tensor.encoding<{
38+
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
39+
}>
40+
3741
#CSR = #sparse_tensor.encoding<{
3842
map = (d0, d1) -> (d0 : dense, d1 : compressed)
3943
}>
@@ -50,7 +54,7 @@
5054

5155
module {
5256
func.func @add_coo_csr(%arga: tensor<8x8xf32, #CSR>,
53-
%argb: tensor<8x8xf32, #SortedCOO>)
57+
%argb: tensor<8x8xf32, #SortedCOOSoA>)
5458
-> tensor<8x8xf32> {
5559
%empty = tensor.empty() : tensor<8x8xf32>
5660
%zero = arith.constant 0.000000e+00 : f32
@@ -59,7 +63,7 @@ module {
5963
outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>
6064
%0 = linalg.generic #trait
6165
ins(%arga, %argb: tensor<8x8xf32, #CSR>,
62-
tensor<8x8xf32, #SortedCOO>)
66+
tensor<8x8xf32, #SortedCOOSoA>)
6367
outs(%init: tensor<8x8xf32>) {
6468
^bb(%a: f32, %b: f32, %x: f32):
6569
%0 = arith.addf %a, %b : f32
@@ -69,7 +73,7 @@ module {
6973
}
7074

7175
func.func @add_coo_coo(%arga: tensor<8x8xf32, #SortedCOO>,
72-
%argb: tensor<8x8xf32, #SortedCOO>)
76+
%argb: tensor<8x8xf32, #SortedCOOSoA>)
7377
-> tensor<8x8xf32> {
7478
%empty = tensor.empty() : tensor<8x8xf32>
7579
%zero = arith.constant 0.000000e+00 : f32
@@ -78,7 +82,7 @@ module {
7882
outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>
7983
%0 = linalg.generic #trait
8084
ins(%arga, %argb: tensor<8x8xf32, #SortedCOO>,
81-
tensor<8x8xf32, #SortedCOO>)
85+
tensor<8x8xf32, #SortedCOOSoA>)
8286
outs(%init: tensor<8x8xf32>) {
8387
^bb(%a: f32, %b: f32, %x: f32):
8488
%0 = arith.addf %a, %b : f32
@@ -88,12 +92,12 @@ module {
8892
}
8993

9094
func.func @add_coo_coo_out_coo(%arga: tensor<8x8xf32, #SortedCOO>,
91-
%argb: tensor<8x8xf32, #SortedCOO>)
95+
%argb: tensor<8x8xf32, #SortedCOOSoA>)
9296
-> tensor<8x8xf32, #SortedCOO> {
9397
%init = tensor.empty() : tensor<8x8xf32, #SortedCOO>
9498
%0 = linalg.generic #trait
9599
ins(%arga, %argb: tensor<8x8xf32, #SortedCOO>,
96-
tensor<8x8xf32, #SortedCOO>)
100+
tensor<8x8xf32, #SortedCOOSoA>)
97101
outs(%init: tensor<8x8xf32, #SortedCOO>) {
98102
^bb(%a: f32, %b: f32, %x: f32):
99103
%0 = arith.addf %a, %b : f32
@@ -104,7 +108,7 @@ module {
104108

105109

106110
func.func @add_coo_dense(%arga: tensor<8x8xf32>,
107-
%argb: tensor<8x8xf32, #SortedCOO>)
111+
%argb: tensor<8x8xf32, #SortedCOOSoA>)
108112
-> tensor<8x8xf32> {
109113
%empty = tensor.empty() : tensor<8x8xf32>
110114
%zero = arith.constant 0.000000e+00 : f32
@@ -113,7 +117,7 @@ module {
113117
outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>
114118
%0 = linalg.generic #trait
115119
ins(%arga, %argb: tensor<8x8xf32>,
116-
tensor<8x8xf32, #SortedCOO>)
120+
tensor<8x8xf32, #SortedCOOSoA>)
117121
outs(%init: tensor<8x8xf32>) {
118122
^bb(%a: f32, %b: f32, %x: f32):
119123
%0 = arith.addf %a, %b : f32
@@ -154,19 +158,19 @@ module {
154158
%COO_A = sparse_tensor.convert %A
155159
: tensor<8x8xf32> to tensor<8x8xf32, #SortedCOO>
156160
%COO_B = sparse_tensor.convert %B
157-
: tensor<8x8xf32> to tensor<8x8xf32, #SortedCOO>
161+
: tensor<8x8xf32> to tensor<8x8xf32, #SortedCOOSoA>
158162

159163
%C1 = call @add_coo_dense(%A, %COO_B) : (tensor<8x8xf32>,
160-
tensor<8x8xf32, #SortedCOO>)
164+
tensor<8x8xf32, #SortedCOOSoA>)
161165
-> tensor<8x8xf32>
162166
%C2 = call @add_coo_csr(%CSR_A, %COO_B) : (tensor<8x8xf32, #CSR>,
163-
tensor<8x8xf32, #SortedCOO>)
167+
tensor<8x8xf32, #SortedCOOSoA>)
164168
-> tensor<8x8xf32>
165169
%C3 = call @add_coo_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>,
166-
tensor<8x8xf32, #SortedCOO>)
170+
tensor<8x8xf32, #SortedCOOSoA>)
167171
-> tensor<8x8xf32>
168172
%COO_RET = call @add_coo_coo_out_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>,
169-
tensor<8x8xf32, #SortedCOO>)
173+
tensor<8x8xf32, #SortedCOOSoA>)
170174
-> tensor<8x8xf32, #SortedCOO>
171175
%C4 = sparse_tensor.convert %COO_RET : tensor<8x8xf32, #SortedCOO> to tensor<8x8xf32>
172176
//
@@ -204,7 +208,7 @@ module {
204208
bufferization.dealloc_tensor %C4 : tensor<8x8xf32>
205209
bufferization.dealloc_tensor %CSR_A : tensor<8x8xf32, #CSR>
206210
bufferization.dealloc_tensor %COO_A : tensor<8x8xf32, #SortedCOO>
207-
bufferization.dealloc_tensor %COO_B : tensor<8x8xf32, #SortedCOO>
211+
bufferization.dealloc_tensor %COO_B : tensor<8x8xf32, #SortedCOOSoA>
208212
bufferization.dealloc_tensor %COO_RET : tensor<8x8xf32, #SortedCOO>
209213

210214

0 commit comments

Comments
 (0)