[mlir][sparse] initialize slice-driven loop-related fields in one place #76099

Merged
PeimingLiu merged 1 commit into llvm:main from cleanup-slice on Dec 20, 2023

Conversation

PeimingLiu
Member

No description provided.

@llvmbot added the mlir:sparse (Sparse compiler in MLIR) and mlir labels on Dec 20, 2023

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/76099.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp (+2-12)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+62-56)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+11-16)
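The essence of the CodegenEnv change in this PR is that the dependent-level callback no longer translates the merger's (loop, coefficient) list into [tid, lvl] pairs; it forwards it unchanged, and LoopEmitter::initialize now sorts the dependences by loop id itself. The following is a minimal, self-contained C++ sketch of that callback shape, not the MLIR sources: the free-standing getDependentLoops stub, the example affine index 2 * d0 + d1, and the main driver are hypothetical stand-ins for Merger::getDependentLoops and the surrounding compiler state.

// Hypothetical standalone sketch of the simplified dependent-level getter.
#include <iostream>
#include <utility>
#include <vector>

using LoopId = unsigned;                           // id of a loop index variable
using LoopCoeffPair = std::pair<LoopId, unsigned>; // (loop, coefficient)

// Stand-in for Merger::getDependentLoops(t, lvl): for an affine index such as
// 2 * d0 + d1 it returns the loops the level depends on, with their coefficients.
std::vector<LoopCoeffPair> getDependentLoops(unsigned t, unsigned lvl) {
  (void)t;
  (void)lvl; // a real implementation would look these up per [tid, lvl]
  return {{/*d0=*/0, 2}, {/*d1=*/1, 1}};
}

int main() {
  // New-style getter: no per-entry translation into TensorLevel, just forward.
  auto dependentLvlGetter = [](unsigned t, unsigned lvl) -> std::vector<LoopCoeffPair> {
    return getDependentLoops(t, lvl);
  };
  for (auto [loop, coeff] : dependentLvlGetter(/*t=*/0, /*lvl=*/1))
    std::cout << "loop d" << loop << " with coefficient " << coeff << "\n";
  return 0;
}

Because LoopEmitter::initialize now sorts the returned list by loop id (see the hunk in the diff below), the caller no longer has to hand the dependences over in topological order of the iteration graph.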
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index 4bd3af2d3f2f6a..d3de55e4d59bd8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -85,7 +85,6 @@ void CodegenEnv::startEmit() {
     for (Level lvl = 0; lvl < lvlRank; lvl++)
       sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
   }
-
   loopEmitter.initialize(
       tensors,
       StringAttr::get(linalgOp.getContext(),
@@ -95,17 +94,8 @@ void CodegenEnv::startEmit() {
       // TODO: compute the map and pass it to loop emitter directly instead of
       // passing in a callback.
       /*dependentLvlGetter=*/
-      [this](TensorId t,
-             Level lvl) -> std::vector<std::pair<TensorLevel, unsigned>> {
-        // Translates from a list of loop indices to a list of [tid, lvl] pair.
-        std::vector<LoopCoeffPair> &rLoops = merger().getDependentLoops(t, lvl);
-        std::vector<std::pair<TensorLevel, unsigned>> ret;
-        ret.reserve(rLoops.size());
-        for (auto [loop, coeff] : rLoops) {
-          TensorLevel tl = makeTensorLevel(merger().getLoopDefiningLvl(loop));
-          ret.emplace_back(tl, coeff);
-        };
-        return ret;
+      [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
+        return merger().getDependentLoops(t, lvl);
       });
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index d60b6ccd732167..607a0cf12f70db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -391,13 +391,18 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                                  /*posTupleNum=*/Value(), std::nullopt, 0);
     if (dimGetter && !isSynTensor(tid)) {
       for (Level l = 0; l < lvlRank; l++) {
-        dependentLvlMap[tid][l] = dimGetter(tid, l);
+        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
+        // Sort the dependences by loop order.
+        std::sort(deps.begin(), deps.end(),
+                  [](auto &lhs, auto &rhs) { return lhs.first < rhs.first; });
+
+        dependentLvlMap[tid][l] = std::move(deps);
         unsigned depends = dependentLvlMap[tid][l].size();
         if (depends == 0)
           continue;
-        sliceMeta[tid][l].assign(depends, std::make_pair(nullptr, 0));
+        sliceMeta[tid][l].reserve(depends);
         // We need `depends - 1` slices to fully reduce the affine expression.
-        slicePosBuffer[tid][l].assign(depends - 1, nullptr);
+        slicePosBuffer[tid][l].reserve(depends - 1);
       }
     }
   }
@@ -487,35 +492,70 @@ void LoopEmitter::initializeLoopEmit(
     // hoist the code ouside if-conditions.
   }
 
-  Type indexType = builder.getIndexType();
-  Value c0 = constantZero(builder, loc, indexType);
+  initSliceDriven(builder, loc);
+}
+
+void LoopEmitter::initSliceDriven(OpBuilder &builder, Location loc) {
+  Value c0 = C_IDX(0);
   for (TensorId t = 0, e = tensors.size(); t < e; t++) {
     auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
     if (!rtp)
       continue;
 
     Level lvlRank = SparseTensorType(rtp).getLvlRank();
+
+    // Compute the dependency reduction order.
+    auto remDepStack = dependentLvlMap;
+    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
     for (Level lvl = 0; lvl < lvlRank; lvl++) {
-      if (!dependentLvlMap[t][lvl].empty()) {
-        ArrayRef<std::pair<TensorLevel, unsigned>> depLvls =
-            dependentLvlMap[t][lvl];
-        // Needs at least two operands to form a non-trivial affine expression.
-        assert(depLvls.size() == sliceMeta[t][lvl].size());
-
-        Value size = c0;
-        for (int e = depLvls.size() - 1; e >= 0; e--) {
-          auto [dt, dl] = unpackTensorLevel(depLvls[e].first);
-          unsigned stride = depLvls[e].second;
-          Value stridedSize = lvlSizes[dt][dl];
-          if (stride != 1)
-            stridedSize = MULI(stridedSize, C_IDX(stride));
-          size = ADDI(size, stridedSize);
-          sliceMeta[t][lvl][e] = std::make_pair(size, stride);
-        }
+      // Reverse queue into a stack.
+      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
+      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
+        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
+    }
+
+    if (depRedOrder.empty())
+      continue;
+    std::sort(depRedOrder.begin(), depRedOrder.end(),
+              [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });
+
+    for (auto [loop, t, lvl] : depRedOrder) {
+      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
+      assert(curDep.first == loop);
+      Value size = c0;
+      for (auto [loop, stride] : remDepStack[t][lvl]) {
+        // The synthetic tensor high defines the loop upper bound.
+        Value loopHi = highs[getSynTensorId()][loop];
+        size = ADDI(size, MULI(loopHi, C_IDX(stride)));
       }
+      sliceMeta[t][lvl].emplace_back(size, curDep.second);
+      remDepStack[t][lvl].pop_back();
+
+      // Generate the caches required to quickly compute the next non-empty
+      // slice with an increasing offset for slice-based loops.
+      // We do not need a cache for dense levels.
+      if (!remDepStack[t][lvl].empty() && !isDenseLT(lvls[t][lvl]->getLT())) {
+        Value cnt = C_IDX(1);
+        for (int preLvl = lvl - 1; preLvl >= 0; preLvl--) {
+          if (remDepStack[t][preLvl].empty())
+            break;
+          assert(remDepStack[t][preLvl].size() == 1 && "Not implemented");
+          auto [loop, stride] = remDepStack[t][preLvl].back();
+          assert(stride == 1 && "Not yet implemented");
+          // Accumulates the size required to cache the pLo for the slice.
+          // E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the
+          // second level, we need at most a memref<d0xindex>.
+          //
+          // NOTE: this is apparently an over-approximation when the previous
+          // level is compressed, and we could compute a precise memory size
+          // inside the loops. But that would also require us to allocate/free
+          // memory in loops.
+          cnt = MULI(highs[getSynTensorId()][loop], cnt);
+        }
+        slicePosBuffer[t][lvl].push_back(allocSlicePosBuf(builder, loc, cnt));
+      } // else fully resolved.
     }
   }
-  localInsertPos = builder.getInsertionPoint()->getPrevNode();
 }
 
 void LoopEmitter::categorizeLoopCondition(
@@ -1878,9 +1918,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   // simple dim expression in between).
   assert(lvl == *sliceStack[tid].back().slicedOnLvl + 1);
 
-  // Check slice stack integrity.
-  assert(slicePosBuffer[tid][lvl - 1].size() == sliceStack[tid].back().depth);
-
   SmallVector<const SliceInfo *> unResSlices;
   std::optional<std::pair<TensorId, Level>> firstResLvl;
   for (Level curLvl = lvl; curLvl >= 1; curLvl--) {
@@ -2006,37 +2043,6 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
   if (baseEnc.isSlice())
     llvm_unreachable("TODO: not yet implemented");
 
-  // Generate caches required to fast compute next-non-empty slices with
-  // increasing offset for slice-base loop.
-  // We do not need cache for dense levels.
-  if (slicePosBuffer[tid][lvl][0] == nullptr && !isDenseLT(lvlType)) {
-    OpBuilder::InsertionGuard guard(builder);
-    // The buffer can be reused, and the size is loop invariant: it only
-    // depends on the iteration graph's toposort.
-    builder.setInsertionPointAfter(localInsertPos);
-    Value tupleCnt = C_IDX(1);
-    // Accumlates the size required to cache the pLo for the slice.
-    // E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the second
-    // level. We at most need to a memref<d0xindex>.
-    // NOTE: this is apperantly an over-approximation when the previous
-    // level is compressed, and we can compute a precise memory size
-    // inside the loops. But that would also requires us to allocate/free
-    // memorys in loops.
-    // TODO: Maybe using allocaScopeOp inside the loop to resolve the issue?
-    for (Level curLevel = lvl;
-         curLevel >= 1 && !lvlFullyResolved(tid, curLevel - 1); curLevel--) {
-      // We only handle cases when all the previously unresolved levels are
-      // fully reduced.
-      assert(depFullyReduced(tid, curLevel - 1));
-      assert(!sliceMeta[tid][curLevel - 1].empty());
-      auto [sz, stride] = sliceMeta[tid][curLevel - 1].back();
-      assert(stride == 1 && "Not yet implemented");
-      tupleCnt = MULI(tupleCnt, sz);
-    }
-    for (Value &cache : slicePosBuffer[tid][lvl])
-      cache = allocSlicePosBuf(builder, loc, tupleCnt);
-  }
-
   if (sliceInfo.isInitialTensor() ||
       (lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
     // First level or previous level has been full resolved.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index eb577ee4acefef..5cd6136743a85f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -66,19 +66,15 @@ class LoopEmitter {
   // Map from [tid, lvl] to a list of dependent [tidlvl, coeffecient] for
   // subscript expressions on sparse tensors.
   //
-  // E.g., for affine index (2 * d0 + d1), it depends on two tidlvls that
-  // defines d0 and d1 (for affine expression reduction) and uses 2 and 1 for
-  // cofficients on d0, d1 respectively.
-  // If the list is empty, it means that there is no affine expression on the
-  // input [tid, lvl].
+  // E.g., for affine index (2 * d0 + d1), it depends on loops d0 and d1 (for
+  // affine expression reduction) and uses 2 and 1 as coefficients on d0, d1
+  // respectively. If the list is empty, it means that there is no affine
+  // expression on the input [tid, lvl].
   //
-  // NOTE: The caller is responsible to ensure that the order of the returned
-  // list to be consistent with the topological order of the iteration graph,
-  // otherwise the loop emitter might reduce a wrong dependent index variable
-  // when generating slice-driven loops.
+  // NOTE: LoopEmitter assumes that the loop id is consistent with the loop
+  // order, i.e., loop `d0` will be generated before loop `d1`.
   using DependentLvlGetter =
-      function_ref<std::vector<std::pair<TensorLevel, unsigned>>(TensorId,
-                                                                 Level)>;
+      function_ref<std::vector<std::pair<LoopId, unsigned>>(TensorId, Level)>;
 
   LoopEmitter() = default;
 
@@ -534,6 +530,8 @@ class LoopEmitter {
   // Slice-driven loop related methods.
   //
 
+  void initSliceDriven(OpBuilder &builder, Location loc);
+
   /// Retrieves the most recent slice on lvl. To reduce affine expression like
   /// d0 + d1 + d2, we need two slices (one of size d1 + d2, and the other of
   /// size d2). This methods returns the latter slice (of size d2).
@@ -621,9 +619,6 @@ class LoopEmitter {
   bool hasOutput;
   bool isSparseOut;
 
-  /// The insertion point to allocate top level local variables.
-  Operation *localInsertPos;
-
   //
   // Fields which have `numTensor` many entries.
   //
@@ -645,7 +640,7 @@ class LoopEmitter {
   std::vector<std::vector<Value>> highs;
   std::vector<std::vector<Value>> lvlSizes;
   std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
-  std::vector<Value> valBuffer;                       // to_value
+  std::vector<Value> valBuffer; // to_value
 
   //
   // Slice-driven loops related fields.
@@ -659,7 +654,7 @@ class LoopEmitter {
 
   // Map from [tid, level] to a list of dependent [tidlevel, coefficient].
   // See comments for `DependentLvlGetter`.
-  std::vector<std::vector<std::vector<std::pair<TensorLevel, unsigned>>>>
+  std::vector<std::vector<std::vector<std::pair<LoopId, unsigned>>>>
       dependentLvlMap;
 
   // The cached position buffer for the slices, they serve the same purpose as

@PeimingLiu PeimingLiu merged commit cf4dd91 into llvm:main Dec 20, 2023
@PeimingLiu PeimingLiu deleted the cleanup-slice branch December 20, 2023 22:21
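To summarize the new initialization flow for anyone reading the diff top to bottom: initSliceDriven now builds a dependency reduction order up front, turning each [tensor, level] dependence list into a stack sorted by loop id, and for every reduction step it records a (size, stride) entry in sliceMeta, where the size is the sum of loopHi * stride over the dependences still unresolved for that level. Below is a hypothetical standalone sketch of that flow with plain integers standing in for MLIR Values; the loop bounds (10 and 20) and the affine index 2 * d0 + d1 are made up for illustration and are not part of the PR.

// Hypothetical standalone sketch of the dependency reduction order computed
// by the new initSliceDriven, using plain integers instead of MLIR Values.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>

using LoopId = unsigned;

int main() {
  // Loop upper bounds, standing in for highs[getSynTensorId()][loop].
  std::vector<int> loopHi = {10, 20};

  // Per-level dependences for one tensor: level 0 has a trivial index,
  // level 1 carries the affine index 2 * d0 + d1.
  std::vector<std::vector<std::pair<LoopId, unsigned>>> deps = {
      {},
      {{0, 2}, {1, 1}},
  };

  // Reverse each per-level queue into a stack and collect the global
  // dependency reduction order, sorted by loop id.
  auto remDepStack = deps;
  std::vector<std::pair<LoopId, std::size_t>> depRedOrder; // (loop, lvl)
  for (std::size_t lvl = 0; lvl < deps.size(); ++lvl) {
    std::reverse(remDepStack[lvl].begin(), remDepStack[lvl].end());
    for (auto [loop, coeff] : deps[lvl]) {
      (void)coeff;
      depRedOrder.emplace_back(loop, lvl);
    }
  }
  std::sort(depRedOrder.begin(), depRedOrder.end());

  // One (size, stride) entry per reduction step, like sliceMeta[t][lvl].
  std::vector<std::vector<std::pair<int, unsigned>>> sliceMeta(deps.size());
  for (auto [loop, lvl] : depRedOrder) {
    auto curDep = remDepStack[lvl].back();
    assert(curDep.first == loop);
    int size = 0; // sum of loopHi * stride over still-unresolved dependences
    for (auto [l, stride] : remDepStack[lvl])
      size += loopHi[l] * static_cast<int>(stride);
    sliceMeta[lvl].emplace_back(size, curDep.second);
    remDepStack[lvl].pop_back();
    std::cout << "lvl " << lvl << ": slice size " << size << ", stride "
              << curDep.second << "\n";
  }
  // Prints "lvl 1: slice size 40, stride 2" then "lvl 1: slice size 20, stride 1".
  return 0;
}

Since the slice sizes are now derived entirely from the synthetic tensor's loop bounds, the slice metadata and position buffers can be set up in one place instead of lazily inside genSliceBegin, which is what allowed the localInsertPos field to be dropped.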