Skip to content

Reapply "[mlir][sparse] remove LevelType enum, construct LevelType from LevelFormat and Properties" (#81923) #81934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
527 changes: 202 additions & 325 deletions mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions mlir/lib/CAPI/Dialect/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ static_assert(
"MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");

static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
static_cast<int>(LevelPropertyNondefault::Nonordered) &&
static_cast<int>(LevelPropNonDefault::Nonordered) &&
static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
static_cast<int>(LevelPropertyNondefault::Nonunique),
static_cast<int>(LevelPropNonDefault::Nonunique),
"MlirSparseTensorLevelProperty (C-API) and "
"LevelPropertyNondefault (C++) mismatch");

Expand Down Expand Up @@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat
mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
LevelType lt =
static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt));
return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt());
}

int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
Expand All @@ -96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
const enum MlirSparseTensorLevelPropertyNondefault *properties,
unsigned size, unsigned n, unsigned m) {

std::vector<LevelPropertyNondefault> props;
std::vector<LevelPropNonDefault> props;
for (unsigned i = 0; i < size; i++)
props.push_back(static_cast<LevelPropertyNondefault>(properties[i]));
props.push_back(static_cast<LevelPropNonDefault>(properties[i]));

return static_cast<MlirSparseTensorLevelType>(
*buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m));
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
"expected valid level property (e.g. nonordered, nonunique or high)")
if (strVal.compare("nonunique") == 0) {
*properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
} else if (strVal.compare("nonordered") == 0) {
*properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
} else {
parser.emitError(loc, "unknown level property: ") << strVal;
return failure();
Expand Down
16 changes: 12 additions & 4 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
using namespace mlir;
using namespace mlir::sparse_tensor;

// Provide an llvm::hash_value overload for LevelType so that attributes
// containing LevelType values (e.g. SparseTensorEncodingAttr) remain hashable
// after LevelType stopped being a plain enum.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  // Hash the underlying 64-bit encoding of the level type.
  const auto rawBits = static_cast<uint64_t>(lt);
  return llvm::hash_value(rawBits);
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -83,11 +91,11 @@ void StorageLayout::foreachField(
}
// The values array.
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
LevelType::Undef)))
LevelFormat::Undef)))
return;
// Put metadata at the end.
if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
LevelType::Undef)))
LevelFormat::Undef)))
return;
}

Expand Down Expand Up @@ -341,7 +349,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const {

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
if (!getImpl())
return LevelType::Dense;
return LevelFormat::Dense;
assert(l < getLvlRank() && "Level is out of bounds");
return getLvlTypes()[l];
}
Expand Down Expand Up @@ -975,7 +983,7 @@ static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
SmallVector<LevelType> lts;
for (auto lt : enc.getLvlTypes())
lts.push_back(*buildLevelType(*getLevelFormat(lt), true, true));
lts.push_back(lt.stripProperties());

return SparseTensorEncodingAttr::get(
enc.getContext(), lts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ static bool isZeroValue(Value val) {
static bool isSparseTensor(Value v) {
auto enc = getSparseTensorEncoding(v.getType());
return enc && !llvm::all_of(enc.getLvlTypes(),
[](auto lt) { return lt == LevelType::Dense; });
[](auto lt) { return lt == LevelFormat::Dense; });
}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SparseLevel : public SparseTensorLevel {
class DenseLevel : public SparseTensorLevel {
public:
DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
: SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize),
: SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize),
encoded(encoded) {}

Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
Expand Down Expand Up @@ -1275,7 +1275,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
: b.create<tensor::DimOp>(l, t, lvl).getResult();

switch (*getLevelFormat(lt)) {
switch (lt.getLvlFmt()) {
case LevelFormat::Dense:
return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
case LevelFormat::Compressed: {
Expand All @@ -1296,6 +1296,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value crd = genToCoordinates(b, l, t, lvl);
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
}
case LevelFormat::Undef:
llvm_unreachable("undefined level format");
}
llvm_unreachable("unrecognizable level format");
}
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
syntheticTensor(numInputOutputTensors),
numTensors(numInputOutputTensors + 1), numLoops(numLoops),
hasSparseOut(false),
lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
lvlTypes(numTensors,
std::vector<LevelType>(numLoops, LevelFormat::Undef)),
loopToLvl(numTensors,
std::vector<std::optional<Level>>(numLoops, std::nullopt)),
lvlToLoop(numTensors,
Expand Down
34 changes: 17 additions & 17 deletions mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,11 @@ class MergerTest3T1L : public MergerTestBase {
MergerTest3T1L() : MergerTestBase(3, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
// Tensor 0: sparse input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: sparse input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
// Tensor 2: dense output vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
}
};

Expand All @@ -327,13 +327,13 @@ class MergerTest4T1L : public MergerTestBase {
MergerTest4T1L() : MergerTestBase(4, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
// Tensor 0: sparse input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: sparse input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
// Tensor 2: sparse input vector
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
// Tensor 3: dense output vector
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
}
};

Expand All @@ -347,11 +347,11 @@ class MergerTest3T1LD : public MergerTestBase {
MergerTest3T1LD() : MergerTestBase(3, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
// Tensor 0: sparse input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: dense input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
// Tensor 2: dense output vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
}
};

Expand All @@ -365,13 +365,13 @@ class MergerTest4T1LU : public MergerTestBase {
MergerTest4T1LU() : MergerTestBase(4, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
// Tensor 0: undef input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
// Tensor 1: dense input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
// Tensor 2: undef input vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Undef);
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
// Tensor 3: dense output vector.
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
}
};

Expand All @@ -387,11 +387,11 @@ class MergerTest3T1LSo : public MergerTestBase {
EXPECT_TRUE(merger.getSynTensorID() == tid(3));
merger.setHasSparseOut(true);
// Tensor 0: undef input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
// Tensor 1: undef input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Undef);
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
// Tensor 2: sparse output vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
}
};

Expand Down