
Commit 755285f

[mlir][sparse] Factoring out LoopEmitter::isValidLevel
Depends On D146674

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146676
1 parent: d0acc6f


2 files changed (+7, -3)


mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 3 additions & 3 deletions
@@ -456,7 +456,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
   for (auto [t, l] : llvm::zip(tids, lvls)) {
     // TODO: this check for validity of the (t,l) pairs should be
     // checked/enforced at the callsites, if possible.
-    assert(t < lvlTypes.size() && l < lvlTypes[t].size());
+    assert(isValidLevel(t, l));
     assert(!coords[t][l]); // We cannot re-enter the same level
     const auto lvlTp = lvlTypes[t][l];
     const bool isSparse = isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp);
@@ -572,7 +572,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
 Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
     OpBuilder &builder, Location loc, TensorId tid, Level lvl,
     AffineExpr affine, MutableArrayRef<Value> reduc) {
-  assert(tid < lvlTypes.size() && lvl < lvlTypes[tid].size());
+  assert(isValidLevel(tid, lvl));
   assert(!affine.isa<AffineDimExpr>() && !isDenseDLT(lvlTypes[tid][lvl]));
   // We can not re-enter the same level.
   assert(!coords[tid][lvl]);
@@ -862,7 +862,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
 
 void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                              TensorId tid, Level dstLvl) {
-  assert(tid < lvlTypes.size() && dstLvl < lvlTypes[tid].size());
+  assert(isValidLevel(tid, dstLvl));
   const auto lvlTp = lvlTypes[tid][dstLvl];
 
   if (isDenseDLT(lvlTp))

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

Lines changed: 4 additions & 0 deletions
@@ -265,6 +265,10 @@ class LoopEmitter {
     return isOutputTensor(tid) && isSparseOut;
   }
 
+  bool isValidLevel(TensorId tid, Level lvl) const {
+    return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
+  }
+
   /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
   /// that `tensor[0...lvl-1]` loops have already been set up.
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
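
Taken together, the change replaces the repeated inline bounds checks with a single named predicate that each loop-entry routine asserts. The following is a minimal, standalone C++ sketch of that pattern, not the actual MLIR classes: LoopEmitterSketch, the simplified DimLevelType enum, and the getLvlType accessor are illustrative placeholders; only isValidLevel mirrors the helper added by this commit.

// Standalone sketch of the refactoring pattern. TensorId, Level, and
// DimLevelType are simplified stand-ins for the real sparse-tensor types.
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

enum class DimLevelType { Dense, Compressed, Singleton };
using TensorId = unsigned;
using Level = uint64_t;

class LoopEmitterSketch {
public:
  explicit LoopEmitterSketch(std::vector<std::vector<DimLevelType>> types)
      : lvlTypes(std::move(types)) {}

  // The factored-out check: a (tid, lvl) pair is valid when it indexes
  // into the per-tensor table of level types.
  bool isValidLevel(TensorId tid, Level lvl) const {
    return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
  }

  // Hypothetical call site: asserts the named predicate instead of
  // repeating the raw comparison, as the commit does in
  // enterLoopOverTensorAtLvl and the other entry points.
  DimLevelType getLvlType(TensorId tid, Level lvl) const {
    assert(isValidLevel(tid, lvl));
    return lvlTypes[tid][lvl];
  }

private:
  std::vector<std::vector<DimLevelType>> lvlTypes;
};

int main() {
  LoopEmitterSketch emitter({{DimLevelType::Dense, DimLevelType::Compressed}});
  assert(emitter.isValidLevel(0, 1));
  assert(!emitter.isValidLevel(1, 0));
  (void)emitter.getLvlType(0, 0);
  return 0;
}

Naming the predicate keeps each assertion self-describing and gives later patches a single place to tighten the validity rules for (tensor, level) pairs.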
