Skip to content

Commit d213220

Browse files
authored
[mlir][sparse] fixed naming consistency (#73053)
All DLT related methods have DLT at end, removed stale TODO
1 parent a756a6b commit d213220

File tree

4 files changed

+10
-12
lines changed

4 files changed

+10
-12
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,12 @@ constexpr bool is2OutOf4DLT(DimLevelType dlt) {
282282
}
283283

284284
/// Check if the `DimLevelType` needs positions array.
285-
constexpr bool isDLTWithPos(DimLevelType dlt) {
285+
constexpr bool isWithPosDLT(DimLevelType dlt) {
286286
return isCompressedDLT(dlt) || isLooseCompressedDLT(dlt);
287287
}
288288

289289
/// Check if the `DimLevelType` needs coordinates array.
290-
constexpr bool isDLTWithCrd(DimLevelType dlt) {
290+
constexpr bool isWithCrdDLT(DimLevelType dlt) {
291291
return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
292292
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
293293
}
@@ -311,10 +311,8 @@ constexpr std::optional<LevelFormat> getLevelFormat(DimLevelType dlt) {
311311
}
312312

313313
/// Convert a LevelFormat to its corresponding DimLevelType with the given
314-
/// properties. Returns std::nullopt when the properties are not applicable for
315-
/// the input level format.
316-
/// TODO: factor out a new LevelProperties type so we can add new properties
317-
/// without changing this function's signature
314+
/// properties. Returns std::nullopt when the properties are not applicable
315+
/// for the input level format.
318316
constexpr std::optional<DimLevelType>
319317
buildLevelType(LevelFormat lf, bool ordered, bool unique) {
320318
auto dlt = static_cast<DimLevelType>(static_cast<uint8_t>(lf) |

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,8 @@ class SparseTensorType {
302302
bool is2OutOf4Lvl(Level l) const { return is2OutOf4DLT(getLvlType(l)); }
303303
bool isOrderedLvl(Level l) const { return isOrderedDLT(getLvlType(l)); }
304304
bool isUniqueLvl(Level l) const { return isUniqueDLT(getLvlType(l)); }
305-
bool isWithPos(Level l) const { return isDLTWithPos(getLvlType(l)); }
306-
bool isWithCrd(Level l) const { return isDLTWithCrd(getLvlType(l)); }
305+
bool isWithPos(Level l) const { return isWithPosDLT(getLvlType(l)); }
306+
bool isWithCrd(Level l) const { return isWithCrdDLT(getLvlType(l)); }
307307

308308
/// Returns the coordinate-overhead bitwidth, defaulting to zero.
309309
unsigned getCrdWidth() const { return enc ? enc.getCrdWidth() : 0; }

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ void StorageLayout::foreachField(
7272
// Per-level storage.
7373
for (Level l = 0; l < end; l++) {
7474
const auto dlt = lvlTypes[l];
75-
if (isDLTWithPos(dlt)) {
75+
if (isWithPosDLT(dlt)) {
7676
if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
7777
return;
7878
}
79-
if (isDLTWithCrd(dlt)) {
79+
if (isWithCrdDLT(dlt)) {
8080
if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
8181
return;
8282
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,7 +1341,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
13411341
continue;
13421342
}
13431343

1344-
if (isDLTWithPos(dlt)) {
1344+
if (isWithPosDLT(dlt)) {
13451345
assert(isCompressedDLT(dlt) || isLooseCompressedDLT(dlt));
13461346
if (isLooseCompressedDLT(dlt)) {
13471347
memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
@@ -1356,7 +1356,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
13561356
memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
13571357
posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
13581358
}
1359-
assert(isDLTWithCrd(dlt) && lvl <= trailCOOStart);
1359+
assert(isWithCrdDLT(dlt) && lvl <= trailCOOStart);
13601360
// FIXME: This seems to be unnecessarily complex, can we simplify it?
13611361
if (lvl == trailCOOStart) {
13621362
Value cooSz = rewriter.create<arith::MulIOp>(

0 commit comments

Comments
 (0)