Skip to content

[mlir][sparse] optimize memory loads to SSA values when generating sp… #74787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 8, 2023

Conversation

PeimingLiu
Copy link
Member

…arse conv.

@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Dec 7, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 7, 2023

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

…arse conv.


Patch is 51.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74787.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+41-99)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h (+3-10)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+179-194)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 75121b5e3ce2e..26d6ea908cf38 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -148,39 +148,29 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
 // Helper functions that load/store into the position buffer for slice-driven
 // loops.
 // The sliced pointer buffer is orgnized as:
-// [size, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
-//                                  [pHi0, pHi1, pHi2, ...],
-//                                  [pNx0, pNx1, pNx2, ...]]
+// [[pLo0, pLo1, pLo2, ...],
+//  [pHi0, pHi1, pHi2, ...],
+//  [pNx0, pNx1, pNx2, ...]]
 static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
                               Value tupleCnt) {
   Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
   // Additional two metadata {memSize, idx} at head.
-  bufSz = ADDI(bufSz, C_IDX(2));
   return genAlloca(builder, loc, bufSz, builder.getIndexType());
 }
-// TODO: We should use SSA value for it.
-// Gets and sets metadata.
-static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
-  return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
-}
-static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
-                              Value pPtr) {
-  builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
-}
 
 // Gets and sets position values for slice-driven loops.
 enum class SlicePosKind { kLo, kHi, kNext };
 static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
                             Value tupleIdx, SlicePosKind posKind) {
   Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
-  Value tupleCnt = DIVUI(SUBI(dim, C_IDX(2)), C_IDX(kSliceIterWidth));
+  Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth));
   switch (posKind) {
   case SlicePosKind::kLo:
-    return ADDI(tupleIdx, C_IDX(2));
+    return tupleIdx;
   case SlicePosKind::kHi:
-    return ADDI(tupleIdx, ADDI(tupleCnt, C_IDX(2)));
+    return ADDI(tupleIdx, tupleCnt);
   case SlicePosKind::kNext:
-    return ADDI(tupleIdx, ADDI(tupleCnt, ADDI(tupleCnt, C_IDX(2))));
+    return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2)));
   }
   llvm_unreachable("unexpected kind");
 }
@@ -344,6 +334,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->dependentLvlMap.assign(
       numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
   this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
+  this->sliceTupleNxStartIdx.assign(numTensors, std::vector<Value>());
+  this->sliceTupleFwdCnt.assign(numTensors, std::vector<Value>());
+  this->trivialSlice.assign(numTensors, std::vector<bool>());
   this->sliceMeta.assign(
       numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
   this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
@@ -394,6 +387,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     dependentLvlMap[tid].assign(
         lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
     slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
+    sliceTupleNxStartIdx[tid].assign(lvlRank, Value());
+    sliceTupleFwdCnt[tid].assign(lvlRank, Value());
+    trivialSlice[tid].assign(lvlRank, false);
     sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
     sliceStack[tid].emplace_back(/*minCrd=*/Value(),
                                  /*offset=*/Value(), /*isNonEmpty*/ Value(),
@@ -806,6 +802,7 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
     assert(ivs.size() == 1);
     // Coord is the relative offset related to its parents.
     assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement");
+    sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
     // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
     Value posit = ivs[0];
     Value crdBuf = coordinatesBuffers[tid][lvl];
@@ -1324,6 +1321,12 @@ void LoopEmitter::enterTensorsAtDenseLvls(
       } else {
         posits[tid][lvl] =
             genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
+        Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl]
+                           ? C_IDX(0)
+                           : sliceTupleFwdCnt[tid][lvl - 1];
+        Value sz = sliceMeta[tid][lvl].back().first;
+        Value mul = MULI(fwdCnt, sz);
+        sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv);
       }
       levelReducedDep[tid][lvl]++;
     } else {
@@ -1357,13 +1360,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
       assert(isDenseLT(lvlTypes[tid][lvl]));
       assert(*info.slicedOnLvl == lvl);
       (void)reduced;
-      // Resets slices pointers as the resolved slices are invalidated after we
-      // moves forward to the next slice.
-      invalidateSliceIterIdx(rewriter, loc, tid, lvl);
       info.minCrd = info.offset = info.isNonEmpty = Value();
-    } else {
-      forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
-                                      constantIndex(rewriter, loc, 1));
     }
     levelReducedDep[tid][lvl]--;
   }
@@ -1443,54 +1440,6 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
   }
 }
 
-void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
-                                                  Location loc, TensorId tid,
-                                                  Level rootLvl, Value fcnt) {
-
-  auto stt = getSparseTensorType(tensors[tid]);
-
-  // Finds a [Lvl, leafLvl) range, and all level in between are fully reduced
-  // sparse levels (but not resolved). Since we forward an iterator at higher
-  // level of the tree, the subtree need to be pruned.
-  Level leafLvl = rootLvl + 1;
-  while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) &&
-         !stt.isDenseLvl(leafLvl)) {
-    leafLvl++;
-  }
-
-  Level curLvl = rootLvl + 1;
-  Value nxPosPtr = nullptr;
-  if (curLvl < leafLvl) {
-    assert(!isDenseLT(lvlTypes[tid][curLvl]));
-    // The first compressed level, setting up the position pointer for it.
-    Value sPosBuf = slicePosBuffer[tid][curLvl].back();
-    // One step forwards in the parent level result in forwarding one `segment`
-    // in the child sparse level.
-    Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
-    Value cPosPtr = ADDI(fcnt, pPosPtr);                    // current ptr
-    updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
-    // Loads the position pointer start for next level.
-    nxPosPtr =
-        loadSlicePos(builder, loc, sPosBuf, cPosPtr, SlicePosKind::kNext);
-    curLvl++;
-  }
-
-  // TODO: This is not always needed, but we did it unconditionally for now for
-  // simplicity.
-  // It is only needed when `curLvl` is forwarded without traversing its child
-  // level (e.g., the level is in a conjunctive lattices and got pruned), such
-  // that the position pointer is not forwarded inside the loop.
-  for (; curLvl < leafLvl; curLvl++) {
-    assert(nxPosPtr);
-    if (!isDenseLT(lvlTypes[tid][curLvl])) {
-      Value sPosBuf = slicePosBuffer[tid][curLvl].back();
-      updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
-      nxPosPtr =
-          loadSlicePos(builder, loc, sPosBuf, nxPosPtr, SlicePosKind::kNext);
-    }
-  }
-}
-
 void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                 MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
@@ -1540,13 +1489,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       forwarded = CMPI(eq, coords[tid][lvl], iv);
       operands.push_back(SELECT(forwarded, nxPos, pos));
     }
-    {
-      OpBuilder::InsertionGuard guard(builder);
-      auto ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, forwarded,
-                                            /*else=*/false);
-      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, one);
-    }
     // The coordinate is invalid now.
     coords[tid][lvl] = nullptr;
 
@@ -1916,8 +1858,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
     pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
                        ADDI(posits[tid][lvl - 1], c1));
   }
-  // Fills out pIdxBuffer[tid][lvl][0] with [0, pLo, pHi]
-  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
+  // Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
   updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
   updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
   // Slice over a resolved parent, we only need one pair of pos hi and lo to
@@ -2056,8 +1997,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   Value isNonEmpty = result[0];
   Value minCrd = result[1];
   // Two metadata [memSize, idx].
-  // TODO: Can use an SSA value for these two metadata
-  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
   // FIXME: we need the relative offset related to the base slice.
   Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
   sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, result[2], lvl,
@@ -2066,16 +2005,30 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
 
 bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
                                 Level lvl) {
+  Value curLvlIdx = C_IDX(0);
   if (depFullyReduced(tid, lvl)) {
-    // Do not need to prepare for slice driven loop on dense level after it is
-    // fully reduced.
+    if (lvl == 0 || trivialSlice[tid][lvl]) {
+      sliceTupleNxStartIdx[tid][lvl] = C_IDX(0);
+    } else {
+      if (isDenseLT(lvlTypes[tid][lvl])) {
+        sliceTupleNxStartIdx[tid][lvl] = sliceTupleNxStartIdx[tid][lvl - 1];
+      } else {
+        assert(isCompressedLT(lvlTypes[tid][lvl]));
+        curLvlIdx = ADDI(sliceTupleNxStartIdx[tid][lvl - 1],
+                         sliceTupleFwdCnt[0][lvl - 1]);
+        sliceTupleNxStartIdx[tid][lvl] =
+            loadSlicePos(builder, loc, slicePosBuffer[tid][lvl].back(),
+                         curLvlIdx, SlicePosKind::kNext);
+      }
+    }
     if (isDenseLT(lvlTypes[tid][lvl]))
       return true;
+
+    Value sPosBuf = slicePosBuffer[tid][lvl].back();
     // If constraints on the tensor is fully resolved. We do not need to
     // generates slice begin any more, instead we fall back to TACO-based
     // algorithm to (co)iterates over the slice.
-    Value sPosBuf = slicePosBuffer[tid][lvl].back();
-    Value tupleIdx = loadSlicePosPtr(builder, loc, sPosBuf);
+    Value tupleIdx = curLvlIdx;
     posits[tid][lvl] =
         loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
     highs[tid][lvl] =
@@ -2134,23 +2087,16 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
   if (sliceInfo.isInitialTensor() ||
       (lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
     // First level or previous level has been full resolved.
+    trivialSlice[tid][lvl] = true;
     genResolvedSliceBegin(builder, loc, tid, lvl);
   } else {
     // The previous level has not been full resolved.
+    trivialSlice[tid][lvl] = false;
     genUnResolvedSliceBegin(builder, loc, tid, lvl);
   }
   return false;
 }
 
-void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
-                                         TensorId tid, Level lvl) {
-  for (unsigned i = 0; i <= lvl; i++) {
-    if (!isDenseLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
-      updateSlicePosPtr(builder, loc, slicePosBuffer[tid][i].back(), C_IDX(0));
-    }
-  }
-}
-
 std::tuple<Value, Value, Value>
 LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
                                    TensorId tid, Level lvl) {
@@ -2175,10 +2121,6 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
   //   isNonEmpty = false;
   //
   Value absOffset = info.offset;
-  // Resets slices pointers as the resolved slices are invalidated after we
-  // moves forward to the next slice.
-  invalidateSliceIterIdx(builder, loc, tid, lvl);
-
   SmallVector<Value, 3> reduc = {info.minCrd, info.isNonEmpty, absOffset};
   Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1];
   Value fastPathP = CMPI(ugt, info.minCrd, absOffset);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 5e51cb2110fa1..fa8b0076f733b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -453,11 +453,6 @@ class LoopEmitter {
     return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
   }
 
-  /// Forwards the (conceptual) "tree iterator" when iterating over a fully
-  /// reduced slice created by index-reduction.
-  void forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc,
-                                       TensorId tid, Level lvl, Value fcnt);
-
   /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
   /// that `tensor[0...lvl-1]` loops have already been set up.
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
@@ -610,11 +605,6 @@ class LoopEmitter {
   void genUnResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
                                Level lvl);
 
-  /// Invalidates the index kept in slice postion buffers (by setting it to
-  /// zero).
-  /// TODO: We should instead use an SSA value for the index.
-  void invalidateSliceIterIdx(OpBuilder &builder, Location loc, TensorId tid,
-                              Level lvl);
   /// Generates code to get the first non-empty slice of tid on lvl.
   /// return true if has already been resolved.
   bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
@@ -683,6 +673,9 @@ class LoopEmitter {
   // But they always starts with the first pidx pointing to coord > slice.offset
   // to avoid iteration from the beginning.
   std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
+  std::vector<std::vector<Value>> sliceTupleNxStartIdx;
+  std::vector<std::vector<Value>> sliceTupleFwdCnt;
+  std::vector<std::vector<bool>> trivialSlice;
 
   // The (size, stride) for each conceptual slice used for index reduction
   // loops.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 02cc5d1e2ef34..a3c1e76a3d09a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -12,241 +12,226 @@
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant -2 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 3 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 5 : index
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 2 : index
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 0 : i32
-// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant false
-// CHECK-DAG:       %[[VAL_13:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
-// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
-// CHECK:           %[[VAL_19:.*]] = memref.alloca() : memref<11xindex>
-// CHECK:           %[[VAL_20:.*]] = memref.alloca() : memref<5xindex>
-// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_7]]] : memref<?xindex>
-// CHECK:           memref.store %[[VAL_10]], %[[VAL_20]]{{\[}}%[[VAL_7]]] : memref<5xindex>
-// CHECK:           memref.store %[[VAL_10]], %[[VAL_20]]{{\[}}%[[VAL_9]]] : memref<5xindex>
-// CHECK:           memref.store %[[VAL_21]], %[[VAL_20]]{{\[}}%[[VAL_6]]] : memref<5xindex>
-// CHECK:           %[[VAL_22:.*]] = arith.cmpi ugt, %[[VAL_21]], %[[VAL_10]] : index
-// CHECK:           %[[VAL_23:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_10]]] : memref<?xindex>
-// CHECK:           %[[VAL_24:.*]] = arith.cmpi uge, %[[VAL_23]], %[[VAL_6]] : index
-// CHECK:           %[[VAL_25:.*]] = arith.andi %[[VAL_22]], %[[VAL_24]] : i1
-// CHECK:           %[[VAL_26:.*]] = arith.addi %[[VAL_23]], %[[VAL_3]] : index
-// CHECK:           %[[VAL_27:.*]] = arith.select %[[VAL_25]], %[[VAL_26]], %[[VAL_10]] : index
-// CHECK:           %[[VAL_28:.*]]:3 = scf.while (%[[VAL_29:.*]] = %[[VAL_22]], %[[VAL_30:.*]] = %[[VAL_23]], %[[VAL_31:.*]] = %[[VAL_27]], %[[VAL_32:.*]] = %[[VAL_13]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
-// CHECK:             scf.condition(%[[VAL_29]]) %[[VAL_30]], %[[VAL_31]], %[[VAL_32]] : index, index, tensor<6x6xi32, #sparse>
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : i32
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_11:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_17:.*]] = memref.alloca() : memref<9xindex>
+// CHECK-DAG:       %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG:       %[[VAL_19:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
+// CHECK:           memref.store %[[VAL_19]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
+// CHECK:           %[[VAL_20:.*]] = arith.cmpi ugt, %[[VAL_19]], %[[VAL_8]] : index
+// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK:           %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1
+// CHECK:           %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_8]] : index
+// CHECK:           %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
+// CHECK:             scf.condition(%[[VAL_27]]) %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : index, index, tensor<6x6xi32, #sparse>
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index, %[[VAL_35:.*]]: tensor<6x6xi32, #sparse>):
...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 7, 2023

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

…arse conv.


Patch is 51.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74787.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+41-99)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h (+3-10)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+179-194)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 75121b5e3ce2e..26d6ea908cf38 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -148,39 +148,29 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
 // Helper functions that load/store into the position buffer for slice-driven
 // loops.
 // The sliced pointer buffer is orgnized as:
-// [size, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
-//                                  [pHi0, pHi1, pHi2, ...],
-//                                  [pNx0, pNx1, pNx2, ...]]
+// [[pLo0, pLo1, pLo2, ...],
+//  [pHi0, pHi1, pHi2, ...],
+//  [pNx0, pNx1, pNx2, ...]]
 static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
                               Value tupleCnt) {
   Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
   // Additional two metadata {memSize, idx} at head.
-  bufSz = ADDI(bufSz, C_IDX(2));
   return genAlloca(builder, loc, bufSz, builder.getIndexType());
 }
-// TODO: We should use SSA value for it.
-// Gets and sets metadata.
-static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
-  return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
-}
-static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
-                              Value pPtr) {
-  builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
-}
 
 // Gets and sets position values for slice-driven loops.
 enum class SlicePosKind { kLo, kHi, kNext };
 static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
                             Value tupleIdx, SlicePosKind posKind) {
   Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
-  Value tupleCnt = DIVUI(SUBI(dim, C_IDX(2)), C_IDX(kSliceIterWidth));
+  Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth));
   switch (posKind) {
   case SlicePosKind::kLo:
-    return ADDI(tupleIdx, C_IDX(2));
+    return tupleIdx;
   case SlicePosKind::kHi:
-    return ADDI(tupleIdx, ADDI(tupleCnt, C_IDX(2)));
+    return ADDI(tupleIdx, tupleCnt);
   case SlicePosKind::kNext:
-    return ADDI(tupleIdx, ADDI(tupleCnt, ADDI(tupleCnt, C_IDX(2))));
+    return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2)));
   }
   llvm_unreachable("unexpected kind");
 }
@@ -344,6 +334,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->dependentLvlMap.assign(
       numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
   this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
+  this->sliceTupleNxStartIdx.assign(numTensors, std::vector<Value>());
+  this->sliceTupleFwdCnt.assign(numTensors, std::vector<Value>());
+  this->trivialSlice.assign(numTensors, std::vector<bool>());
   this->sliceMeta.assign(
       numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
   this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
@@ -394,6 +387,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     dependentLvlMap[tid].assign(
         lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
     slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
+    sliceTupleNxStartIdx[tid].assign(lvlRank, Value());
+    sliceTupleFwdCnt[tid].assign(lvlRank, Value());
+    trivialSlice[tid].assign(lvlRank, false);
     sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
     sliceStack[tid].emplace_back(/*minCrd=*/Value(),
                                  /*offset=*/Value(), /*isNonEmpty*/ Value(),
@@ -806,6 +802,7 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
     assert(ivs.size() == 1);
     // Coord is the relative offset related to its parents.
     assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement");
+    sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
     // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
     Value posit = ivs[0];
     Value crdBuf = coordinatesBuffers[tid][lvl];
@@ -1324,6 +1321,12 @@ void LoopEmitter::enterTensorsAtDenseLvls(
       } else {
         posits[tid][lvl] =
             genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
+        Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl]
+                           ? C_IDX(0)
+                           : sliceTupleFwdCnt[tid][lvl - 1];
+        Value sz = sliceMeta[tid][lvl].back().first;
+        Value mul = MULI(fwdCnt, sz);
+        sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv);
       }
       levelReducedDep[tid][lvl]++;
     } else {
@@ -1357,13 +1360,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
       assert(isDenseLT(lvlTypes[tid][lvl]));
       assert(*info.slicedOnLvl == lvl);
       (void)reduced;
-      // Resets slices pointers as the resolved slices are invalidated after we
-      // moves forward to the next slice.
-      invalidateSliceIterIdx(rewriter, loc, tid, lvl);
       info.minCrd = info.offset = info.isNonEmpty = Value();
-    } else {
-      forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
-                                      constantIndex(rewriter, loc, 1));
     }
     levelReducedDep[tid][lvl]--;
   }
@@ -1443,54 +1440,6 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
   }
 }
 
-void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
-                                                  Location loc, TensorId tid,
-                                                  Level rootLvl, Value fcnt) {
-
-  auto stt = getSparseTensorType(tensors[tid]);
-
-  // Finds a [Lvl, leafLvl) range, and all level in between are fully reduced
-  // sparse levels (but not resolved). Since we forward an iterator at higher
-  // level of the tree, the subtree need to be pruned.
-  Level leafLvl = rootLvl + 1;
-  while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) &&
-         !stt.isDenseLvl(leafLvl)) {
-    leafLvl++;
-  }
-
-  Level curLvl = rootLvl + 1;
-  Value nxPosPtr = nullptr;
-  if (curLvl < leafLvl) {
-    assert(!isDenseLT(lvlTypes[tid][curLvl]));
-    // The first compressed level, setting up the position pointer for it.
-    Value sPosBuf = slicePosBuffer[tid][curLvl].back();
-    // One step forwards in the parent level result in forwarding one `segment`
-    // in the child sparse level.
-    Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
-    Value cPosPtr = ADDI(fcnt, pPosPtr);                    // current ptr
-    updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
-    // Loads the position pointer start for next level.
-    nxPosPtr =
-        loadSlicePos(builder, loc, sPosBuf, cPosPtr, SlicePosKind::kNext);
-    curLvl++;
-  }
-
-  // TODO: This is not always needed, but we did it unconditionally for now for
-  // simplicity.
-  // It is only needed when `curLvl` is forwarded without traversing its child
-  // level (e.g., the level is in a conjunctive lattices and got pruned), such
-  // that the position pointer is not forwarded inside the loop.
-  for (; curLvl < leafLvl; curLvl++) {
-    assert(nxPosPtr);
-    if (!isDenseLT(lvlTypes[tid][curLvl])) {
-      Value sPosBuf = slicePosBuffer[tid][curLvl].back();
-      updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
-      nxPosPtr =
-          loadSlicePos(builder, loc, sPosBuf, nxPosPtr, SlicePosKind::kNext);
-    }
-  }
-}
-
 void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                 MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
@@ -1540,13 +1489,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       forwarded = CMPI(eq, coords[tid][lvl], iv);
       operands.push_back(SELECT(forwarded, nxPos, pos));
     }
-    {
-      OpBuilder::InsertionGuard guard(builder);
-      auto ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, forwarded,
-                                            /*else=*/false);
-      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, one);
-    }
     // The coordinate is invalid now.
     coords[tid][lvl] = nullptr;
 
@@ -1916,8 +1858,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
     pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
                        ADDI(posits[tid][lvl - 1], c1));
   }
-  // Fills out pIdxBuffer[tid][lvl][0] with [0, pLo, pHi]
-  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
+  // Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
   updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
   updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
   // Slice over a resolved parent, we only need one pair of pos hi and lo to
@@ -2056,8 +1997,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   Value isNonEmpty = result[0];
   Value minCrd = result[1];
   // Two metadata [memSize, idx].
-  // TODO: Can use an SSA value for these two metadata
-  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
   // FIXME: we need the relative offset related to the base slice.
   Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
   sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, result[2], lvl,
@@ -2066,16 +2005,30 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
 
 bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
                                 Level lvl) {
+  Value curLvlIdx = C_IDX(0);
   if (depFullyReduced(tid, lvl)) {
-    // Do not need to prepare for slice driven loop on dense level after it is
-    // fully reduced.
+    if (lvl == 0 || trivialSlice[tid][lvl]) {
+      sliceTupleNxStartIdx[tid][lvl] = C_IDX(0);
+    } else {
+      if (isDenseLT(lvlTypes[tid][lvl])) {
+        sliceTupleNxStartIdx[tid][lvl] = sliceTupleNxStartIdx[tid][lvl - 1];
+      } else {
+        assert(isCompressedLT(lvlTypes[tid][lvl]));
+        curLvlIdx = ADDI(sliceTupleNxStartIdx[tid][lvl - 1],
+                         sliceTupleFwdCnt[0][lvl - 1]);
+        sliceTupleNxStartIdx[tid][lvl] =
+            loadSlicePos(builder, loc, slicePosBuffer[tid][lvl].back(),
+                         curLvlIdx, SlicePosKind::kNext);
+      }
+    }
     if (isDenseLT(lvlTypes[tid][lvl]))
       return true;
+
+    Value sPosBuf = slicePosBuffer[tid][lvl].back();
     // If constraints on the tensor is fully resolved. We do not need to
     // generates slice begin any more, instead we fall back to TACO-based
     // algorithm to (co)iterates over the slice.
-    Value sPosBuf = slicePosBuffer[tid][lvl].back();
-    Value tupleIdx = loadSlicePosPtr(builder, loc, sPosBuf);
+    Value tupleIdx = curLvlIdx;
     posits[tid][lvl] =
         loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
     highs[tid][lvl] =
@@ -2134,23 +2087,16 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
   if (sliceInfo.isInitialTensor() ||
       (lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
     // First level or previous level has been full resolved.
+    trivialSlice[tid][lvl] = true;
     genResolvedSliceBegin(builder, loc, tid, lvl);
   } else {
     // The previous level has not been full resolved.
+    trivialSlice[tid][lvl] = false;
     genUnResolvedSliceBegin(builder, loc, tid, lvl);
   }
   return false;
 }
 
-void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
-                                         TensorId tid, Level lvl) {
-  for (unsigned i = 0; i <= lvl; i++) {
-    if (!isDenseLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
-      updateSlicePosPtr(builder, loc, slicePosBuffer[tid][i].back(), C_IDX(0));
-    }
-  }
-}
-
 std::tuple<Value, Value, Value>
 LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
                                    TensorId tid, Level lvl) {
@@ -2175,10 +2121,6 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
   //   isNonEmpty = false;
   //
   Value absOffset = info.offset;
-  // Resets slices pointers as the resolved slices are invalidated after we
-  // moves forward to the next slice.
-  invalidateSliceIterIdx(builder, loc, tid, lvl);
-
   SmallVector<Value, 3> reduc = {info.minCrd, info.isNonEmpty, absOffset};
   Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1];
   Value fastPathP = CMPI(ugt, info.minCrd, absOffset);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 5e51cb2110fa1..fa8b0076f733b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -453,11 +453,6 @@ class LoopEmitter {
     return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
   }
 
-  /// Forwards the (conceptual) "tree iterator" when iterating over a fully
-  /// reduced slice created by index-reduction.
-  void forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc,
-                                       TensorId tid, Level lvl, Value fcnt);
-
   /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
   /// that `tensor[0...lvl-1]` loops have already been set up.
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
@@ -610,11 +605,6 @@ class LoopEmitter {
   void genUnResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
                                Level lvl);
 
-  /// Invalidates the index kept in slice postion buffers (by setting it to
-  /// zero).
-  /// TODO: We should instead use an SSA value for the index.
-  void invalidateSliceIterIdx(OpBuilder &builder, Location loc, TensorId tid,
-                              Level lvl);
   /// Generates code to get the first non-empty slice of tid on lvl.
   /// return true if has already been resolved.
   bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
@@ -683,6 +673,9 @@ class LoopEmitter {
   // But they always starts with the first pidx pointing to coord > slice.offset
   // to avoid iteration from the beginning.
   std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
+  std::vector<std::vector<Value>> sliceTupleNxStartIdx;
+  std::vector<std::vector<Value>> sliceTupleFwdCnt;
+  std::vector<std::vector<bool>> trivialSlice;
 
   // The (size, stride) for each conceptual slice used for index reduction
   // loops.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 02cc5d1e2ef34..a3c1e76a3d09a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -12,241 +12,226 @@
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant -2 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 3 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 5 : index
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 2 : index
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 0 : i32
-// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant false
-// CHECK-DAG:       %[[VAL_13:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
-// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
-// CHECK:           %[[VAL_19:.*]] = memref.alloca() : memref<11xindex>
-// CHECK:           %[[VAL_20:.*]] = memref.alloca() : memref<5xindex>
-// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_7]]] : memref<?xindex>
-// CHECK:           memref.store %[[VAL_10]], %[[VAL_20]]{{\[}}%[[VAL_7]]] : memref<5xindex>
-// CHECK:           memref.store %[[VAL_10]], %[[VAL_20]]{{\[}}%[[VAL_9]]] : memref<5xindex>
-// CHECK:           memref.store %[[VAL_21]], %[[VAL_20]]{{\[}}%[[VAL_6]]] : memref<5xindex>
-// CHECK:           %[[VAL_22:.*]] = arith.cmpi ugt, %[[VAL_21]], %[[VAL_10]] : index
-// CHECK:           %[[VAL_23:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_10]]] : memref<?xindex>
-// CHECK:           %[[VAL_24:.*]] = arith.cmpi uge, %[[VAL_23]], %[[VAL_6]] : index
-// CHECK:           %[[VAL_25:.*]] = arith.andi %[[VAL_22]], %[[VAL_24]] : i1
-// CHECK:           %[[VAL_26:.*]] = arith.addi %[[VAL_23]], %[[VAL_3]] : index
-// CHECK:           %[[VAL_27:.*]] = arith.select %[[VAL_25]], %[[VAL_26]], %[[VAL_10]] : index
-// CHECK:           %[[VAL_28:.*]]:3 = scf.while (%[[VAL_29:.*]] = %[[VAL_22]], %[[VAL_30:.*]] = %[[VAL_23]], %[[VAL_31:.*]] = %[[VAL_27]], %[[VAL_32:.*]] = %[[VAL_13]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
-// CHECK:             scf.condition(%[[VAL_29]]) %[[VAL_30]], %[[VAL_31]], %[[VAL_32]] : index, index, tensor<6x6xi32, #sparse>
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : i32
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_11:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_17:.*]] = memref.alloca() : memref<9xindex>
+// CHECK-DAG:       %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG:       %[[VAL_19:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
+// CHECK:           memref.store %[[VAL_19]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
+// CHECK:           %[[VAL_20:.*]] = arith.cmpi ugt, %[[VAL_19]], %[[VAL_8]] : index
+// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK:           %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1
+// CHECK:           %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_8]] : index
+// CHECK:           %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
+// CHECK:             scf.condition(%[[VAL_27]]) %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : index, index, tensor<6x6xi32, #sparse>
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index, %[[VAL_35:.*]]: tensor<6x6xi32, #sparse>):
...
[truncated]

@PeimingLiu PeimingLiu merged commit baa192e into llvm:main Dec 8, 2023
@PeimingLiu PeimingLiu deleted the ssa branch December 8, 2023 17:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants