Skip to content

Commit aaf9164

Browse files
authored
Reapply "[mlir][sparse] remove LevelType enum, construct LevelType from LevelFormat and Properties" (llvm#81923) (llvm#81934)
1 parent 761113a commit aaf9164

File tree

8 files changed

+245
-357
lines changed

8 files changed

+245
-357
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 202 additions & 325 deletions
Large diffs are not rendered by default.

mlir/lib/CAPI/Dialect/SparseTensor.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ static_assert(
3434
"MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");
3535

3636
static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
37-
static_cast<int>(LevelPropertyNondefault::Nonordered) &&
37+
static_cast<int>(LevelPropNonDefault::Nonordered) &&
3838
static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
39-
static_cast<int>(LevelPropertyNondefault::Nonunique),
39+
static_cast<int>(LevelPropNonDefault::Nonunique),
4040
"MlirSparseTensorLevelProperty (C-API) and "
4141
"LevelPropertyNondefault (C++) mismatch");
4242

@@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat
8080
mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
8181
LevelType lt =
8282
static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
83-
return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt));
83+
return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt());
8484
}
8585

8686
int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
@@ -96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
9696
const enum MlirSparseTensorLevelPropertyNondefault *properties,
9797
unsigned size, unsigned n, unsigned m) {
9898

99-
std::vector<LevelPropertyNondefault> props;
99+
std::vector<LevelPropNonDefault> props;
100100
for (unsigned i = 0; i < size; i++)
101-
props.push_back(static_cast<LevelPropertyNondefault>(properties[i]));
101+
props.push_back(static_cast<LevelPropNonDefault>(properties[i]));
102102

103103
return static_cast<MlirSparseTensorLevelType>(
104104
*buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m));

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
8888
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
8989
"expected valid level property (e.g. nonordered, nonunique or high)")
9090
if (strVal.compare("nonunique") == 0) {
91-
*properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
91+
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
9292
} else if (strVal.compare("nonordered") == 0) {
93-
*properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
93+
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
9494
} else {
9595
parser.emitError(loc, "unknown level property: ") << strVal;
9696
return failure();

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@
3535
using namespace mlir;
3636
using namespace mlir::sparse_tensor;
3737

38+
// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
39+
// well.
40+
namespace mlir::sparse_tensor {
41+
llvm::hash_code hash_value(LevelType lt) {
42+
return llvm::hash_value(static_cast<uint64_t>(lt));
43+
}
44+
} // namespace mlir::sparse_tensor
45+
3846
//===----------------------------------------------------------------------===//
3947
// Local Convenience Methods.
4048
//===----------------------------------------------------------------------===//
@@ -83,11 +91,11 @@ void StorageLayout::foreachField(
8391
}
8492
// The values array.
8593
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
86-
LevelType::Undef)))
94+
LevelFormat::Undef)))
8795
return;
8896
// Put metadata at the end.
8997
if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
90-
LevelType::Undef)))
98+
LevelFormat::Undef)))
9199
return;
92100
}
93101

@@ -341,7 +349,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const {
341349

342350
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
343351
if (!getImpl())
344-
return LevelType::Dense;
352+
return LevelFormat::Dense;
345353
assert(l < getLvlRank() && "Level is out of bounds");
346354
return getLvlTypes()[l];
347355
}
@@ -975,7 +983,7 @@ static SparseTensorEncodingAttr
975983
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
976984
SmallVector<LevelType> lts;
977985
for (auto lt : enc.getLvlTypes())
978-
lts.push_back(*buildLevelType(*getLevelFormat(lt), true, true));
986+
lts.push_back(lt.stripProperties());
979987

980988
return SparseTensorEncodingAttr::get(
981989
enc.getContext(), lts,

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ static bool isZeroValue(Value val) {
4646
static bool isSparseTensor(Value v) {
4747
auto enc = getSparseTensorEncoding(v.getType());
4848
return enc && !llvm::all_of(enc.getLvlTypes(),
49-
[](auto lt) { return lt == LevelType::Dense; });
49+
[](auto lt) { return lt == LevelFormat::Dense; });
5050
}
5151
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
5252

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class SparseLevel : public SparseTensorLevel {
6363
class DenseLevel : public SparseTensorLevel {
6464
public:
6565
DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
66-
: SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize),
66+
: SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize),
6767
encoded(encoded) {}
6868

6969
Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
@@ -1275,7 +1275,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
12751275
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
12761276
: b.create<tensor::DimOp>(l, t, lvl).getResult();
12771277

1278-
switch (*getLevelFormat(lt)) {
1278+
switch (lt.getLvlFmt()) {
12791279
case LevelFormat::Dense:
12801280
return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
12811281
case LevelFormat::Compressed: {
@@ -1296,6 +1296,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
12961296
Value crd = genToCoordinates(b, l, t, lvl);
12971297
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
12981298
}
1299+
case LevelFormat::Undef:
1300+
llvm_unreachable("undefined level format");
12991301
}
13001302
llvm_unreachable("unrecognizable level format");
13011303
}

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
226226
syntheticTensor(numInputOutputTensors),
227227
numTensors(numInputOutputTensors + 1), numLoops(numLoops),
228228
hasSparseOut(false),
229-
lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
229+
lvlTypes(numTensors,
230+
std::vector<LevelType>(numLoops, LevelFormat::Undef)),
230231
loopToLvl(numTensors,
231232
std::vector<std::optional<Level>>(numLoops, std::nullopt)),
232233
lvlToLoop(numTensors,

mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,11 @@ class MergerTest3T1L : public MergerTestBase {
313313
MergerTest3T1L() : MergerTestBase(3, 1) {
314314
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
315315
// Tensor 0: sparse input vector.
316-
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
316+
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
317317
// Tensor 1: sparse input vector.
318-
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
318+
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
319319
// Tensor 2: dense output vector.
320-
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
320+
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
321321
}
322322
};
323323

@@ -327,13 +327,13 @@ class MergerTest4T1L : public MergerTestBase {
327327
MergerTest4T1L() : MergerTestBase(4, 1) {
328328
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
329329
// Tensor 0: sparse input vector.
330-
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
330+
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
331331
// Tensor 1: sparse input vector.
332-
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
332+
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
333333
// Tensor 2: sparse input vector
334-
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
334+
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
335335
// Tensor 3: dense output vector
336-
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
336+
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
337337
}
338338
};
339339

@@ -347,11 +347,11 @@ class MergerTest3T1LD : public MergerTestBase {
347347
MergerTest3T1LD() : MergerTestBase(3, 1) {
348348
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
349349
// Tensor 0: sparse input vector.
350-
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
350+
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
351351
// Tensor 1: dense input vector.
352-
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
352+
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
353353
// Tensor 2: dense output vector.
354-
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
354+
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
355355
}
356356
};
357357

@@ -365,13 +365,13 @@ class MergerTest4T1LU : public MergerTestBase {
365365
MergerTest4T1LU() : MergerTestBase(4, 1) {
366366
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
367367
// Tensor 0: undef input vector.
368-
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
368+
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
369369
// Tensor 1: dense input vector.
370-
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
370+
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
371371
// Tensor 2: undef input vector.
372-
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Undef);
372+
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
373373
// Tensor 3: dense output vector.
374-
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
374+
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
375375
}
376376
};
377377

@@ -387,11 +387,11 @@ class MergerTest3T1LSo : public MergerTestBase {
387387
EXPECT_TRUE(merger.getSynTensorID() == tid(3));
388388
merger.setHasSparseOut(true);
389389
// Tensor 0: undef input vector.
390-
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
390+
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
391391
// Tensor 1: undef input vector.
392-
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Undef);
392+
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
393393
// Tensor 2: sparse output vector.
394-
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
394+
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
395395
}
396396
};
397397

0 commit comments

Comments
 (0)