
[mlir][sparse] use a common util function to query the tensor level set in a lattice point #76764


Merged
PeimingLiu merged 1 commit into llvm:main from PeimingLiu:cleanup on Jan 2, 2024

Conversation

PeimingLiu (Member)

Use a common util function to query the tensor level set in a lattice point.

@llvmbot added the mlir:sparse (Sparse compiler in MLIR) and mlir labels on Jan 2, 2024
llvmbot (Member) commented on Jan 2, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Use a common util function to query the tensor level set in a lattice point.
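
In essence, the patch merges two helpers that each walked a lattice point and filled their own vectors (`translateBitsToTidLvlPairs` and the collection loop inside `startLoopSeq`) into a single walker, `getAllTidLvlsInLatPoints`, that reports each result through an `llvm::function_ref` callback. Below is a minimal sketch of that pattern; the `TensorLevel` and `AffineExpr` types are simplified stand-ins for the real MLIR types, not the actual definitions:

```cpp
#include "llvm/ADT/STLFunctionalExtras.h" // llvm::function_ref
#include "llvm/ADT/SmallVector.h"
#include <utility>

// Stand-in for the sparsifier's TensorLevel (a packed tensor-id/level pair).
using TensorLevel = unsigned;

// Stand-in for mlir::AffineExpr: a value type that tests false when empty.
struct AffineExpr {
  const void *impl = nullptr;
  explicit operator bool() const { return impl != nullptr; }
};

// The shared walker: visits every (tensor, level) pair in a lattice point
// and reports each one through the callback. Levels addressed by a constant
// affine expression carry a non-null AffineExpr; plain loop levels carry an
// empty one. Returns whether a simple for loop suffices (as in the patch).
static bool getAllTidLvlsInLatPoints(
    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
  static const int someExpr = 0;                 // placeholder "expression"
  callback(/*tidLvl=*/0, AffineExpr{});          // a level driving the loop
  callback(/*tidLvl=*/1, AffineExpr{&someExpr}); // an affine-addressed level
  return true;
}

// Caller 1 (startLoopSeq): only the tensor levels matter.
static void collectForLoopSeq(llvm::SmallVectorImpl<TensorLevel> &tidLvls) {
  getAllTidLvlsInLatPoints(
      [&](TensorLevel tl, AffineExpr) { tidLvls.push_back(tl); });
}

// Caller 2 (translateBitsToTidLvlPairs): split by whether an affine
// expression is attached, preserving the old two-vector interface.
static bool translateBitsToTidLvlPairs(
    llvm::SmallVectorImpl<TensorLevel> &tidLvls,
    llvm::SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
  return getAllTidLvlsInLatPoints([&](TensorLevel tl, AffineExpr exp) {
    if (exp)
      affineTidLvls.emplace_back(tl, exp);
    else
      tidLvls.push_back(tl);
  });
}
```

The diff below applies exactly this shape: `startLoopSeq` keeps only the `TensorLevel`s, while `translateBitsToTidLvlPairs` survives as a thin wrapper with its old two-vector interface.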


Full diff: https://github.com/llvm/llvm-project/pull/76764.diff

1 file affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+86-94)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 934e1e559f44d6..7be2f30d26d8ba 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 // Sparsifier synthesis methods (loop sequence).
 //===----------------------------------------------------------------------===//
 
-/// Starts a loop sequence at given level. Returns true if
-/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId curr, LatSetId lts) {
-  assert(!env.getLoopVar(curr));
-  // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, curr, /*isStart=*/true);
-  // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, curr, /*isStart=*/true);
-  // Emit further intitialization at this loop sequence level.
-  const LatPointId l0 = env.set(lts)[0];
-  bool needsUniv = false;
-
-  SmallVector<TensorLevel> tidLvls;
-  env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
-                                           std::optional<Level> lvl,
-                                           LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == curr);
-    if (isDenseLT(lt) || isUndefLT(lt)) {
-      if (tid == env.merger().getSynTensorID()) {
-        // Needs loop emitter to set up loop bounds for synthetic tensor too if
-        // there is a loop condition imposed on the synthetic tensor.
-        tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth()));
-      }
-      needsUniv = true;
-    }
-    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        is2OutOf4LT(lt) || isIdxReduc) {
-      // Only when this is a index reduction loop, can the lt be undefined.
-      assert(!isUndefLT(lt) || isIdxReduc);
-      // sparse/singleton levels, or a dense/sparse index reduction loop.
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
-    }
-  });
-
-  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
-
-  // Maintain the universal index only if it is actually
-  // consumed by a subsequent lattice point.
-  if (needsUniv) {
-    for (const LatPointId li : env.set(lts).drop_front())
-      if (!env.merger().hasAnySparse(env.lat(li).simple))
-        return true;
-  }
-  return false;
-}
-
-// Generates dense affine address for encoding.
-static void genConstantDenseAddressFromLevel(CodegenEnv &env,
-                                             OpBuilder &builder, TensorId tid,
-                                             Level startLvl) {
-  // TODO: Handle affine expression on output tensor.
-  linalg::GenericOp op = env.op();
-  assert(tid < op.getNumDpsInputs());
-  OpOperand *input = op.getDpsInputOperands()[tid];
-  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
-  const auto enc = getSparseTensorEncoding(input->get().getType());
-  if (enc) {
-    const Location loc = op.getLoc();
-    const TensorId tid = env.makeTensorId(input->getOperandNumber());
-    const Level lvlRank = enc.getLvlRank();
-    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
-    for (Level l = startLvl; l < lvlRank; l++) {
-      AffineExpr lvlExpr = lvlExprs[l];
-      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
-        env.emitter().genDenseAffineAddress(
-            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
-      else
-        return; // break on first non-dense non-constant level
-    }
-  }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
-                                        RewriterBase &rewriter) {
-  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
-    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+static bool getAllTidLvlsInLatPoints(
     CodegenEnv &env, LatPointId li, LoopId curr,
-    SmallVectorImpl<TensorLevel> &tidLvls,
-    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@ static bool translateBitsToTidLvlPairs(
                     LevelType lt, bool isIdxReduc) {
         if (simple[b]) {
           if (isIdxReduc) {
-            tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+            callback(env.makeTensorLevel(tid, *lvl), nullptr);
             numloopCond++;
             return;
           }
@@ -1072,10 +987,10 @@ static bool translateBitsToTidLvlPairs(
             }
           }
           hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
-          tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+          callback(env.makeTensorLevel(tid, *lvl), nullptr);
           numloopCond++;
         } else if (isDenseLT(lt) || isIdxReduc) {
-          tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+          callback(env.makeTensorLevel(tid, *lvl), nullptr);
         } else {
           assert(isUndefLT(lt));
           linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@ static bool translateBitsToTidLvlPairs(
                 // level. We need to generate the address according to the
                 // affine expression. This is also the best place we can do it
                 // to avoid putting it inside inner loops.
-                affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+                callback(env.makeTensorLevel(tid, l), exp);
               }
             }
           }
@@ -1120,15 +1035,14 @@ static bool translateBitsToTidLvlPairs(
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
-    tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+    callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
   }
 
   if (numloopCond == 0) {
     // Corner cases where the loop bound is defined by a *unused* operand, in
     // this case, we just generate a dense "fake" loop by iterating over the
     // synthetic tensor.
-    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
-                                          env.getCurrentDepth()));
+    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
     numloopCond++;
   }
   // If we just need to one loop conditions and the conditions is not imposed on
@@ -1136,6 +1050,84 @@ static bool translateBitsToTidLvlPairs(
   return numloopCond == 1 && !hasNonUnique;
 }
 
+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+                         LoopId curr, LatSetId lts) {
+  assert(!env.getLoopVar(curr));
+  // Emit invariants at this loop sequence level.
+  genInvariants(env, builder, exp, curr, /*isStart=*/true);
+  // Emit access pattern expansion for sparse tensor output.
+  genExpand(env, builder, curr, /*isStart=*/true);
+  // Emit further initialization at this loop sequence level.
+  const LatPointId l0 = env.set(lts)[0];
+
+  SmallVector<TensorLevel> tidLvls;
+  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+    tidLvls.emplace_back(tl);
+  });
+
+  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+  // Maintain the universal index only if it is actually
+  // consumed by a subsequent lattice point.
+  for (const LatPointId li : env.set(lts).drop_front())
+    if (!env.merger().hasAnySparse(env.lat(li).simple))
+      return true;
+
+  return false;
+}
+
+// Generates dense affine address for encoding.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+                                             OpBuilder &builder, TensorId tid,
+                                             Level startLvl) {
+  // TODO: Handle affine expression on output tensor.
+  linalg::GenericOp op = env.op();
+  assert(tid < op.getNumDpsInputs());
+  OpOperand *input = op.getDpsInputOperands()[tid];
+  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+  const auto enc = getSparseTensorEncoding(input->get().getType());
+  if (enc) {
+    const Location loc = op.getLoc();
+    const TensorId tid = env.makeTensorId(input->getOperandNumber());
+    const Level lvlRank = enc.getLvlRank();
+    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+    for (Level l = startLvl; l < lvlRank; l++) {
+      AffineExpr lvlExpr = lvlExprs[l];
+      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+        env.emitter().genDenseAffineAddress(
+            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+      else
+        return; // break on first non-dense non-constant level
+    }
+  }
+}
+
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on any thing.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+                                        RewriterBase &rewriter) {
+  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Return true if the lattices bit can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+    CodegenEnv &env, LatPointId li, LoopId curr,
+    SmallVectorImpl<TensorLevel> &tidLvls,
+    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+  return getAllTidLvlsInLatPoints(env, li, curr,
+                                  [&](TensorLevel tl, AffineExpr exp) {
+                                    if (exp)
+                                      affineTidLvls.emplace_back(tl, exp);
+                                    else
+                                      tidLvls.emplace_back(tl);
+                                  });
+}
+
 /// Starts a single loop in current sequence.
 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               OpBuilder &builder, LoopId curr,

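A note on the design choice: `llvm::function_ref` is a non-owning reference to a callable, so it is cheap to pass by value into a walker like `getAllTidLvlsInLatPoints`, but it must not be stored past the call. A small self-contained illustration of that contract (`countMatching` is a hypothetical helper, not part of this patch):

```cpp
#include "llvm/ADT/STLFunctionalExtras.h" // llvm::function_ref
#include <vector>

// Walks a container and consults a caller-supplied predicate; the
// function_ref is borrowed only for the duration of the call, which is
// exactly how the new callback parameter is used in this patch.
static int countMatching(const std::vector<int> &xs,
                         llvm::function_ref<bool(int)> pred) {
  int n = 0;
  for (int x : xs)
    if (pred(x))
      ++n;
  return n;
}

int main() {
  std::vector<int> xs = {1, 2, 3, 4};
  // Safe: the lambda outlives the call to countMatching.
  return countMatching(xs, [](int x) { return x % 2 == 0; }) == 2 ? 0 : 1;
}
```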
PeimingLiu merged commit d933b88 into llvm:main on Jan 2, 2024
PeimingLiu deleted the cleanup branch on January 2, 2024 at 23:56