Skip to content

Commit 513448d

Browse files
authored
Revert "[mlir][sparse] remove LevelType enum, construct LevelType from LevelF…" (#81923)
Reverts #81799 ; this broke the mlir gcc7 bot.
1 parent 2a9b86c commit 513448d

File tree

8 files changed

+357
-245
lines changed

8 files changed

+357
-245
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 325 additions & 202 deletions
Large diffs are not rendered by default.

mlir/lib/CAPI/Dialect/SparseTensor.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ static_assert(
3434
"MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");
3535

3636
static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
37-
static_cast<int>(LevelPropNonDefault::Nonordered) &&
37+
static_cast<int>(LevelPropertyNondefault::Nonordered) &&
3838
static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
39-
static_cast<int>(LevelPropNonDefault::Nonunique),
39+
static_cast<int>(LevelPropertyNondefault::Nonunique),
4040
"MlirSparseTensorLevelProperty (C-API) and "
4141
"LevelPropertyNondefault (C++) mismatch");
4242

@@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat
8080
mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
8181
LevelType lt =
8282
static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
83-
return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt());
83+
return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt));
8484
}
8585

8686
int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
@@ -96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
9696
const enum MlirSparseTensorLevelPropertyNondefault *properties,
9797
unsigned size, unsigned n, unsigned m) {
9898

99-
std::vector<LevelPropNonDefault> props;
99+
std::vector<LevelPropertyNondefault> props;
100100
for (unsigned i = 0; i < size; i++)
101-
props.push_back(static_cast<LevelPropNonDefault>(properties[i]));
101+
props.push_back(static_cast<LevelPropertyNondefault>(properties[i]));
102102

103103
return static_cast<MlirSparseTensorLevelType>(
104104
*buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m));

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
8888
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
8989
"expected valid level property (e.g. nonordered, nonunique or high)")
9090
if (strVal.compare("nonunique") == 0) {
91-
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
91+
*properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
9292
} else if (strVal.compare("nonordered") == 0) {
93-
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
93+
*properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
9494
} else {
9595
parser.emitError(loc, "unknown level property: ") << strVal;
9696
return failure();

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,6 @@
3535
using namespace mlir;
3636
using namespace mlir::sparse_tensor;
3737

38-
// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
39-
// well.
40-
namespace mlir::sparse_tensor {
41-
llvm::hash_code hash_value(LevelType lt) {
42-
return llvm::hash_value(static_cast<uint64_t>(lt));
43-
}
44-
} // namespace mlir::sparse_tensor
45-
4638
//===----------------------------------------------------------------------===//
4739
// Local Convenience Methods.
4840
//===----------------------------------------------------------------------===//
@@ -91,11 +83,11 @@ void StorageLayout::foreachField(
9183
}
9284
// The values array.
9385
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
94-
LevelFormat::Undef)))
86+
LevelType::Undef)))
9587
return;
9688
// Put metadata at the end.
9789
if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
98-
LevelFormat::Undef)))
90+
LevelType::Undef)))
9991
return;
10092
}
10193

@@ -349,7 +341,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const {
349341

350342
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
351343
if (!getImpl())
352-
return LevelFormat::Dense;
344+
return LevelType::Dense;
353345
assert(l < getLvlRank() && "Level is out of bounds");
354346
return getLvlTypes()[l];
355347
}
@@ -983,7 +975,7 @@ static SparseTensorEncodingAttr
983975
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
984976
SmallVector<LevelType> lts;
985977
for (auto lt : enc.getLvlTypes())
986-
lts.push_back(lt.stripProperties());
978+
lts.push_back(*buildLevelType(*getLevelFormat(lt), true, true));
987979

988980
return SparseTensorEncodingAttr::get(
989981
enc.getContext(), lts,

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ static bool isZeroValue(Value val) {
4646
static bool isSparseTensor(Value v) {
4747
auto enc = getSparseTensorEncoding(v.getType());
4848
return enc && !llvm::all_of(enc.getLvlTypes(),
49-
[](auto lt) { return lt == LevelFormat::Dense; });
49+
[](auto lt) { return lt == LevelType::Dense; });
5050
}
5151
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
5252

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class SparseLevel : public SparseTensorLevel {
6363
class DenseLevel : public SparseTensorLevel {
6464
public:
6565
DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
66-
: SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize),
66+
: SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize),
6767
encoded(encoded) {}
6868

6969
Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
@@ -1275,7 +1275,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
12751275
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
12761276
: b.create<tensor::DimOp>(l, t, lvl).getResult();
12771277

1278-
switch (lt.getLvlFmt()) {
1278+
switch (*getLevelFormat(lt)) {
12791279
case LevelFormat::Dense:
12801280
return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
12811281
case LevelFormat::Compressed: {
@@ -1296,8 +1296,6 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
12961296
Value crd = genToCoordinates(b, l, t, lvl);
12971297
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
12981298
}
1299-
case LevelFormat::Undef:
1300-
llvm_unreachable("undefined level format");
13011299
}
13021300
llvm_unreachable("unrecognizable level format");
13031301
}

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,7 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
226226
syntheticTensor(numInputOutputTensors),
227227
numTensors(numInputOutputTensors + 1), numLoops(numLoops),
228228
hasSparseOut(false),
229-
lvlTypes(numTensors,
230-
std::vector<LevelType>(numLoops, LevelFormat::Undef)),
229+
lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
231230
loopToLvl(numTensors,
232231
std::vector<std::optional<Level>>(numLoops, std::nullopt)),
233232
lvlToLoop(numTensors,

mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,11 @@ class MergerTest3T1L : public MergerTestBase {
313313
MergerTest3T1L() : MergerTestBase(3, 1) {
314314
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
315315
// Tensor 0: sparse input vector.
316-
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
316+
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
317317
// Tensor 1: sparse input vector.
318-
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
318+
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
319319
// Tensor 2: dense output vector.
320-
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
320+
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
321321
}
322322
};
323323

@@ -327,13 +327,13 @@ class MergerTest4T1L : public MergerTestBase {
327327
MergerTest4T1L() : MergerTestBase(4, 1) {
328328
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
329329
// Tensor 0: sparse input vector.
330-
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
330+
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
331331
// Tensor 1: sparse input vector.
332-
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
332+
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
333333
// Tensor 2: sparse input vector
334-
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
334+
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
335335
// Tensor 3: dense output vector
336-
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
336+
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
337337
}
338338
};
339339

@@ -347,11 +347,11 @@ class MergerTest3T1LD : public MergerTestBase {
347347
MergerTest3T1LD() : MergerTestBase(3, 1) {
348348
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
349349
// Tensor 0: sparse input vector.
350-
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
350+
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
351351
// Tensor 1: dense input vector.
352-
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
352+
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
353353
// Tensor 2: dense output vector.
354-
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
354+
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
355355
}
356356
};
357357

@@ -365,13 +365,13 @@ class MergerTest4T1LU : public MergerTestBase {
365365
MergerTest4T1LU() : MergerTestBase(4, 1) {
366366
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
367367
// Tensor 0: undef input vector.
368-
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
368+
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
369369
// Tensor 1: dense input vector.
370-
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
370+
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
371371
// Tensor 2: undef input vector.
372-
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
372+
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Undef);
373373
// Tensor 3: dense output vector.
374-
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
374+
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
375375
}
376376
};
377377

@@ -387,11 +387,11 @@ class MergerTest3T1LSo : public MergerTestBase {
387387
EXPECT_TRUE(merger.getSynTensorID() == tid(3));
388388
merger.setHasSparseOut(true);
389389
// Tensor 0: undef input vector.
390-
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
390+
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
391391
// Tensor 1: undef input vector.
392-
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
392+
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Undef);
393393
// Tensor 2: sparse output vector.
394-
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
394+
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
395395
}
396396
};
397397

0 commit comments

Comments
 (0)