[mlir][sparse] cleanup ldx/idx/depth/at usage #74654

Merged: 1 commit into llvm:main on Dec 6, 2023
Conversation

@aartbik (Contributor) commented Dec 6, 2023

This adds consistent usage of `at` for everything that refers to the current loop nesting. This cleans up some redundant legacy code from when we were still using topSort inside sparsifier code.
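
For readers skimming the diff below, here is a minimal standalone sketch of the convention change (hypothetical, simplified types; `isInvariantOld`/`isInvariantNew` are illustrative names, not the patch's code). The old API threaded both a loop depth and a separate previous-loop index `ldx` (with a `kInvalidId` sentinel); the new API passes a single `at` for the current loop nesting, where `i + 1 == at` means a dimension just became invariant:

```cpp
// Hedged sketch, not the actual MLIR code: LoopId and kInvalidId are
// simplified stand-ins for the sparsifier's types.
#include <cassert>
#include <cstdint>
#include <limits>

using LoopId = uint64_t;
static constexpr LoopId kInvalidId = std::numeric_limits<LoopId>::max();

// Old style: callers passed the loop depth plus a separate `ldx`
// (kInvalidId before any loop has been entered).
static bool isInvariantOld(LoopId i, LoopId loopDepth, LoopId ldx,
                           bool &isAtLoop) {
  if (i == ldx) {
    isAtLoop = true; // invariant exactly at the given loop
    return true;
  }
  return i < loopDepth; // invariant when already generated
}

// New style: a single `at` denotes the current loop nesting;
// `i + 1 == at` means the expression just became invariant here.
static bool isInvariantNew(LoopId i, LoopId at, bool &isAtLoop) {
  if (i + 1 == at) {
    isAtLoop = true; // becomes invariant at the current loop
    return true;
  }
  return i < at; // invariant when already generated
}

int main() {
  // For at > 0 the old ldx was simply at - 1, and the two agree.
  for (LoopId at = 1; at < 8; at++) {
    for (LoopId i = 0; i < 8; i++) {
      bool oldAt = false, newAt = false;
      bool o = isInvariantOld(i, at, at - 1, oldAt);
      bool n = isInvariantNew(i, at, newAt);
      assert(o == n && oldAt == newAt);
    }
  }
  return 0;
}
```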

@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Dec 6, 2023
@llvmbot (Member) commented Dec 6, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/74654.diff

1 File Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+69-70)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d03e9615d340e..6637a26d0e5af 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -44,23 +44,23 @@ using namespace mlir::sparse_tensor;
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//
 
-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
-                              bool &isAtLoop) {
+/// Returns true iff affine expression is invariant. Sets the
+/// parameter `isAtLoop` when expression just became invariant.
+static bool isInvariantAffine(AffineExpr a, LoopId at, bool &isAtLoop) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     const LoopId i = cast<AffineDimExpr>(a).getPosition();
-    if (i == ldx) {
+    if (i + 1 == at) {
       isAtLoop = true;
-      return true; // invariant at given loop
+      return true; // becomes invariant at current loop
     }
-    return i < loopDepth; // invariant when already generated
+    return i < at; // invariant when already generated
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
     auto binOp = cast<AffineBinaryOpExpr>(a);
-    return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
-           isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
+    return isInvariantAffine(binOp.getLHS(), at, isAtLoop) &&
+           isInvariantAffine(binOp.getRHS(), at, isAtLoop);
   }
   default: {
     assert(isa<AffineConstantExpr>(a));
@@ -126,8 +126,8 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
     if (coefficient <= 0)
       return false;
 
-    const LoopId ldx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
-    if (!isUndefLT(merger.getLvlType(tensor, ldx)))
+    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
+    if (!isUndefLT(merger.getLvlType(tensor, idx)))
       return false; // used more than once, e.g., A[i][i]
 
     // TODO: Generalizes the following two cases. A[i] (with trivial index
@@ -135,14 +135,14 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
     // not necessarily need to differentiate them.
     if (!isSubExp) {
       assert(coefficient == 1);
-      merger.setLevelAndType(tensor, ldx, lvl, lt);
+      merger.setLevelAndType(tensor, idx, lvl, lt);
     }
 
     if (isSubExp) {
       // The current loops appears in more than one affine expressions on the
       // same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is
       // used twice.
-      if (merger.hasDependentLvl(ldx, tensor)) {
+      if (merger.hasDependentLvl(idx, tensor)) {
         // TODO: This can be supported by coiterate slices if the loop idx is
         // appeared on affine index for different tensor, or take slice on
         // multiple dimensions when it is on the same tensor.
@@ -154,7 +154,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
         // else increase min(d0_1, d0_2).
         return false;
       }
-      merger.setLoopDependentTensorLevel(ldx, tensor, lvl, lt, coefficient);
+      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
     }
     return true;
   }
@@ -613,9 +613,9 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
   if (kind == TensorExp::Kind::kReduce)
     env.startCustomReduc(e); // enter custom
 
-  Value v0, v1;
   // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
   // based on the type of the other operand.
+  Value v0, v1;
   if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
       env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
     v1 = genExp(env, rewriter, exp.children.e1);
@@ -655,21 +655,21 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
 
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                          LoopId ldx, bool atStart) {
+                          LoopId at, bool atStart) {
   if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
     return;
   if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
     // Inspect tensor indices.
-    bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
     linalg::GenericOp op = env.op();
     OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
     const auto map = op.getMatchingIndexingMap(&t);
     const auto stt = getSparseTensorType(t.get());
     const Level lvlRank = stt.getLvlRank();
     assert(static_cast<Level>(map.getNumResults()) == lvlRank);
+    bool isAtLoop = at == 0; // for scalar tensors
     for (Level l = 0; l < lvlRank; l++) {
       const AffineExpr a = map.getResult(l);
-      if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
+      if (!isInvariantAffine(a, at, /*out*/ isAtLoop))
         return; // still in play
     }
     // All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -705,8 +705,8 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
       env.startCustomReduc(exp); // enter custom
     const ExprId e0 = env.exp(exp).children.e0;
     const ExprId e1 = env.exp(exp).children.e1;
-    genInvariants(env, builder, e0, ldx, atStart);
-    genInvariants(env, builder, e1, ldx, atStart);
+    genInvariants(env, builder, e0, at, atStart);
+    genInvariants(env, builder, e1, at, atStart);
     if (env.exp(exp).kind == TensorExp::Kind::kReduce)
       env.endCustomReduc(); // exit custom
   }
@@ -782,29 +782,28 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
 
 /// Whether or not the current loop being generated should be parallized (if
 /// possible) according to the configuration.
-static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
+static bool shouldTryParallize(CodegenEnv &env, LoopId at,
                                ArrayRef<TensorLevel> tidLvls) {
   linalg::GenericOp op = env.op();
   auto iteratorTypes = op.getIteratorTypesArray();
-  bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
-    // Queries the LT based on the tensor id and loop idx, as requested by
-    // `CodegenEnv::lt(TensorId, LoopIdx)`. The returned LT from CodegenEnv
+  bool isSparse = llvm::any_of(tidLvls, [at, &env](TensorLevel tidLvl) {
+    // Queries the LT based on the tensor and loop id, as requested by
+    // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
     // should be consistent with the LT indexed by <TensorId, Level>.
-    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
+    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, at);
     return isCompressedLT(lt) || isSingletonLT(lt);
   });
-  return isParallelFor(env, isOuter, isSparse);
+  return isParallelFor(env, /*isOuter=*/at == 0, isSparse);
 }
 
 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
 /// can either be a for loop or while loop depending on whether there is at most
 /// one sparse level in the list.
 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
-                                 LoopId idx, ArrayRef<TensorLevel> tidLvls,
+                                 ArrayRef<TensorLevel> tidLvls,
                                  bool tryParallel, bool needsUniv) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
-    // Construct the while-loop with a parameter for each
-    // index.
+    // Construct while-loop with a parameter for each index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
         builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
         /*genDedup=*/true, needsUniv);
@@ -817,12 +816,12 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId at,
                           bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
-  bool tryParallel = shouldTryParallize(env, at, at == 0, tidLvls);
-  return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
+  bool tryParallel = shouldTryParallize(env, at, tidLvls);
+  return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
 }
 
 /// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
+static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                             bool needsUniv) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
@@ -862,7 +861,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
 }
 
 /// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
+static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId at,
                        LatPointId p) {
   Location loc = env.op().getLoc();
   SmallVector<Type> types;
@@ -880,13 +879,13 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
           auto stt = getSparseTensorType(env.op().getInputs()[tid]);
           lt = stt.getLvlType(*lvl);
         }
-        assert(ldx == env.merger().loop(b));
+        assert(at == env.merger().loop(b));
         Value clause;
         if (isCompressedLT(lt) || isSingletonLT(lt) ||
             isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
           assert(lvl.has_value());
           const Value crd = env.emitter().getCoords()[tid][*lvl];
-          const Value lvar = env.getLoopVar(ldx);
+          const Value lvar = env.getLoopVar(at);
           clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  crd, lvar);
         } else {
@@ -943,12 +942,12 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 /// Starts a loop sequence at given level. Returns true if
 /// the universal loop index must be maintained at this level.
 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId idx, LoopId ldx, LatSetId lts) {
-  assert(!env.getLoopVar(idx));
+                         LoopId at, LatSetId lts) {
+  assert(!env.getLoopVar(at));
   // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/true);
+  genInvariants(env, builder, exp, at, /*atStart=*/true);
   // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/true);
+  genExpand(env, builder, at, /*atStart=*/true);
   // Emit further intitialization at this loop sequence level.
   const LatPointId l0 = env.set(lts)[0];
   bool needsUniv = false;
@@ -957,7 +956,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
                                            std::optional<Level> lvl,
                                            LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == idx);
+    assert(env.merger().loop(b) == at);
     if (isDenseLT(lt) || isUndefLT(lt)) {
       if (tid == env.merger().getSynTensorID()) {
         // Needs loop emitter to set up loop bounds for synthetic tensor too if
@@ -988,6 +987,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   return false;
 }
 
+// Generates dense affine address for encoding.
 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                              OpBuilder &builder, TensorId tid,
                                              Level startLvl) {
@@ -1013,30 +1013,30 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
   }
 }
 
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on any thing.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
 static void genInitConstantDenseAddress(CodegenEnv &env,
                                         RewriterBase &rewriter) {
-  // We can generate address for constant affine expression before any loops
-  // starting from the first level as they do not depend on any thing.
-  // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-  // levels can be determined before loops.
   for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
 }
 
 /// Return true if the lattices bit can be iterated by a for loop.
 static bool translateBitsToTidLvlPairs(
-    CodegenEnv &env, LatPointId li, LoopId ldx,
+    CodegenEnv &env, LatPointId li, LoopId at,
     SmallVectorImpl<TensorLevel> &tidLvls,
     SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
-  const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
+  const std::optional<Level> outLvl = env.merger().getLvl(outTid, at);
 
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
-  env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
-                                                std::optional<Level> lvl,
-                                                LevelType lt, bool isIdxReduc) {
+  env.merger().foreachTensorLoopId(li, [&, at](TensorLoopId b, TensorId tid,
+                                               std::optional<Level> lvl,
+                                               LevelType lt, bool isIdxReduc) {
     if (simple[b]) {
       if (isIdxReduc) {
         tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
@@ -1089,11 +1089,11 @@ static bool translateBitsToTidLvlPairs(
         if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
           continue;
 
-        // Constant affine expression are handled in genLoop
+        // Constant affine expression are handled in genLoop.
         if (!isa<AffineConstantExpr>(exp)) {
           bool isAtLoop = false;
-          if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
-              isAtLoop) {
+          assert(at == env.getLoopDepth());
+          if (isInvariantAffine(exp, at + 1, /*out*/ isAtLoop) && isAtLoop) {
             // If the compound affine is invariant and we are right at the
             // level. We need to generate the address according to the
             // affine expression. This is also the best place we can do it
@@ -1105,7 +1105,7 @@ static bool translateBitsToTidLvlPairs(
     }
   });
 
-  if (isDenseLT(env.lt(outTid, ldx))) {
+  if (isDenseLT(env.lt(outTid, at))) {
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
@@ -1131,9 +1131,9 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               LatPointId li, bool needsUniv) {
   // The set of tensors + lvls to generate loops on
   SmallVector<TensorLevel> tidLvls;
+
   // The set of dense tensors with non-trivial affine expression that just
-  // becomes invariant and the address shall now be generated at the current
-  // level.
+  // becomes invariant and the address are generated at the current level.
   SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
   bool isSingleCond =
       translateBitsToTidLvlPairs(env, li, at, tidLvls, affineTidLvls);
@@ -1161,38 +1161,34 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
 
 /// Ends a single loop in current sequence. Returns new values for needsUniv.
 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
-                    LoopId idx, LatPointId li, bool needsUniv,
-                    bool isSingleCond) {
-
+                    LatPointId li, bool needsUniv, bool isSingleCond) {
+  // Either a for-loop or a while-loop that iterates over a slice.
   if (isSingleCond) {
-    // Either a for-loop or a while-loop that iterates over a slice.
     // Any iteration creates a valid lex insert.
     if (env.isReduc() && env.getValidLexInsert())
       env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     // End a while-loop.
-    finalizeWhileOp(env, rewriter, idx, needsUniv);
+    finalizeWhileOp(env, rewriter, needsUniv);
   } else {
     needsUniv = false;
   }
-
   env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
     return std::nullopt;
   });
-
   return needsUniv;
 }
 
 /// Ends a loop sequence at given level.
 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
-                       unsigned idx, unsigned ldx) {
-  assert(!env.getLoopVar(idx));
+                       unsigned at) {
+  assert(!env.getLoopVar(at));
   env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
   // Unmark bookkeeping of invariants and loop index.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/false);
+  genInvariants(env, builder, exp, at, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/false);
+  genExpand(env, builder, at, /*atStart=*/false);
 }
 
 /// Recursively generates code while computing iteration lattices in order
@@ -1200,6 +1196,8 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
 /// and intersections of sparse iterations spaces.
 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                     LoopId at) {
+  assert(at == env.getLoopDepth());
+
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
   if (at == env.getLoopNum()) {
     Value rhs = genExp(env, rewriter, exp);
@@ -1207,13 +1205,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     return;
   }
 
-  // Construct iteration lattices for current loop index, with L0 at top.
-  const LoopId ldx = at == 0 ? sparse_tensor::detail::kInvalidId : at - 1;
+  // Construct iteration lattices for current loop index.
   const LatSetId lts =
       env.merger().optimizeSet(env.merger().buildLattices(exp, at));
 
   // Start a loop sequence.
-  bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
+  bool needsUniv = startLoopSeq(env, rewriter, exp, at, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
   // We cannot change this to `for (const LatPointId li : env.set(lts))`
@@ -1250,11 +1247,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     }
 
     // End a loop.
-    needsUniv = endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
+    needsUniv = endLoop(env, rewriter, loop, at, needsUniv, isSingleCond);
   }
 
   // End a loop sequence.
-  endLoopSeq(env, rewriter, exp, at, ldx);
+  endLoopSeq(env, rewriter, exp, at);
+  assert(at == env.getLoopDepth());
 }
 
 /// Converts the result computed by the sparse kernel into the required form.
@@ -1309,6 +1307,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
           op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
               "before sparsification.");
     }
+
     // Must have been demapped as well if the generic op is sorted.
     assert(!hasAnyNonIdentityOperandsOrResults(op));
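
One subtlety worth calling out (an editorial observation, not part of the patch): in `genInvariants`, the old seed `isAtLoop = ldx == kInvalidId` and the new seed `isAtLoop = at == 0` coincide, because `genStmt` only ever produced `ldx == kInvalidId` at the outermost level. A minimal check under that assumption:

```cpp
// Hypothetical check, not from the patch: the sentinel-based seed and the
// depth-based seed for isAtLoop agree at every loop depth.
#include <cassert>
#include <cstdint>
#include <limits>

using LoopId = uint64_t;
static constexpr LoopId kInvalidId = std::numeric_limits<LoopId>::max();

int main() {
  for (LoopId at = 0; at < 8; at++) {
    // Old convention: ldx was kInvalidId at the outermost level, else at - 1.
    const LoopId ldx = (at == 0) ? kInvalidId : at - 1;
    assert((ldx == kInvalidId) == (at == 0));
  }
  return 0;
}
```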
 

@aartbik aartbik merged commit 98ce2de into llvm:main Dec 6, 2023
@aartbik aartbik deleted the bik branch December 6, 2023 21:23