[mlir][sparse] support tensor.pad on CSR tensors #90687


Merged: 2 commits into llvm:main on May 1, 2024

Conversation

PeimingLiu (Member)

No description provided.
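Since the PR carries no description, here is a minimal sketch of the pattern it enables, inferred from the title and from the integration test in the diff below: applying tensor.pad directly to a sparse tensor with compressed levels. The 2-D CSR encoding, function name, and shapes are illustrative only; the actual test exercises a 4-D CDCC_NHWC encoding.

#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>

// Pads an 8x8 CSR matrix with two zero rows/columns on every side.
func.func @pad_csr(%arg0: tensor<8x8xf32, #CSR>) -> tensor<12x12xf32, #CSR> {
  %f0 = arith.constant 0.0 : f32
  %padded = tensor.pad %arg0 low[2, 2] high[2, 2] {
  ^bb0(%i: index, %j: index):
    tensor.yield %f0 : f32
  } : tensor<8x8xf32, #CSR> to tensor<12x12xf32, #CSR>
  return %padded : tensor<12x12xf32, #CSR>
}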

@llvmbot (Member) commented Apr 30, 2024

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/90687.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (+97-35)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (+3-3)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir (+11-19)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index caf55072ce32e6..112b9f6c252786 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value posLo = MULI(p, lvlSize);
     return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
+    assert(!inPadZone && "Not implemented");
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
     // No need to linearize the position for non-annotated tensors.
     return {C_IDX(0), lvlSize};
@@ -129,18 +131,41 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
 
     assert(parentPos.size() == 1 &&
            "compressed level must be the first non-unique level.");
-    Value p = parentPos.front();
 
-    SmallVector<Value> memCrd(batchPrefix);
-    memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
-    memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
-    return {pLo, pHi};
+    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+      Value p = parentPos.front();
+      SmallVector<Value> memCrd(batchPrefix);
+      memCrd.push_back(p);
+      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+      memCrd.back() = ADDI(p, C_IDX(1));
+      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+      return {pLo, pHi};
+    };
+
+    if (inPadZone == nullptr)
+      return loadRange();
+
+    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+    scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+    // True branch.
+    b.setInsertionPointToStart(posRangeIf.thenBlock());
+    // Returns a "fake" empty range [0, 0) when the parent is in the pad zone.
+    SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
+    b.create<scf::YieldOp>(l, emptyRange);
+
+    // False branch.
+    b.setInsertionPointToStart(posRangeIf.elseBlock());
+    auto [pLo, pHi] = loadRange();
+    SmallVector<Value, 2> loadedRange{pLo, pHi};
+    b.create<scf::YieldOp>(l, loadedRange);
+
+    b.setInsertionPointAfter(posRangeIf);
+    ValueRange posRange = posRangeIf.getResults();
+    return {posRange.front(), posRange.back()};
   }
 };
 
@@ -151,9 +176,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 &&
            "loose-compressed level must be the first non-unique level.");
+    assert(!inPadZone && "Not implemented");
     SmallVector<Value> memCrd(batchPrefix);
     Value p = parentPos.front();
     p = MULI(p, C_IDX(2));
@@ -172,8 +198,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 || parentPos.size() == 2);
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
 
@@ -191,9 +218,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && isUnique() &&
            "n:m level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     // Each n:m blk has exactly n specified elements.
     auto n = getN(lt);
     Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -325,23 +353,7 @@ class TrivialIterator : public ConcreteIterator {
   };
 
   void genInitImpl(OpBuilder &b, Location l,
-                   const SparseIterator *parent) override {
-
-    if (isBatchIterator() && batchCrds.size() <= stl.lvl)
-      batchCrds.resize(stl.lvl + 1, nullptr);
-
-    Value c0 = C_IDX(0);
-    ValueRange pPos = c0;
-    // If the parent iterator is a batch iterator, we also start from 0 (but
-    // on a different batch).
-    if (parent && !parent->isBatchIterator())
-      pPos = parent->getCurPosition();
-
-    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
-    // Seek to the lowest position.
-    seek(posLo);
-  }
+                   const SparseIterator *parent) override;
 
   ValuePair genForCond(OpBuilder &b, Location l) override {
     if (randomAccessible())
@@ -465,8 +477,9 @@ class DedupIterator : public ConcreteIterator {
 // A util base-iterator that delegates all methods to the wrapped iterator.
 class SimpleWrapIterator : public SparseIterator {
 public:
-  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
-      : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
+                     unsigned extraCursorVal = 0)
+      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
 
   SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
     return wrap->getCursorValTypes(b);
@@ -586,9 +599,10 @@ class PadIterator : public SimpleWrapIterator {
 public:
   PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
               Value padHigh)
-      : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
-        padHigh(padHigh) {
-    assert(!randomAccessible() && "Not implemented.");
+      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
+                           wrap->randomAccessible() ? 1 : 0),
+        padLow(padLow), padHigh(padHigh) {
+    // assert(!randomAccessible());
   }
 
   // For LLVM-style RTTI.
@@ -600,6 +614,19 @@ class PadIterator : public SimpleWrapIterator {
     return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
   }
 
+  // For a padded dense iterator, we append an `inPadZone: bool` in addition
+  // to the values used by the wrapped iterator.
+  ValueRange getCurPosition() const override { return getCursor(); }
+
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    SmallVector<Type> ret = wrap->getCursorValTypes(b);
+    // Need an extra boolean value `inPadZone` for the padded dense iterator.
+    if (randomAccessible())
+      ret.push_back(b.getI1Type());
+
+    return ret;
+  }
+
   // The upper bound after padding becomes `size + padLow + padHigh`.
   Value upperBound(OpBuilder &b, Location l) const override {
     return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
@@ -613,6 +640,14 @@ class PadIterator : public SimpleWrapIterator {
 
   void locateImpl(OpBuilder &b, Location l, Value crd) override {
     assert(randomAccessible());
+    wrap->locate(b, l, SUBI(crd, padLow));
+
+    // inPadZone = crd < padLow || crd >= size + padLow.
+    Value inPadLow = CMPI(ult, crd, padLow);
+    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
+    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
+
+    updateCrd(crd);
   }
 
   Value padLow, padHigh;
@@ -1227,6 +1262,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
   return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
 }
 
+void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
+                                  const SparseIterator *parent) {
+
+  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+    batchCrds.resize(stl.lvl + 1, nullptr);
+
+  Value c0 = C_IDX(0);
+  ValueRange pPos = c0;
+  Value inPadZone = nullptr;
+  // If the parent iterator is a batch iterator, we also start from 0 (but
+  // on a different batch).
+  if (parent && !parent->isBatchIterator()) {
+    pPos = parent->getCurPosition();
+    if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
+      // A padded dense iterator creates a "sparse" pad zone, which needs to
+      // be handled specially.
+      inPadZone = pPos.back();
+      pPos = pPos.drop_back();
+    }
+  }
+
+  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
+  // Seek to the lowest position.
+  seek(posLo);
+}
+
 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                                           const SparseIterator *) {
   Value c0 = C_IDX(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 2e7eeb2a05f998..120a806536f190 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -46,9 +46,9 @@ class SparseTensorLevel {
   ///
   /// For a sparse level, [posLo, loopHi) specifies the range of index pointer
   /// to load coordinate from the coordinate buffer.
-  virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
-                                              ValueRange batchPrefix,
-                                              ValueRange parentPos) const = 0;
+  virtual std::pair<Value, Value>
+  peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+              ValueRange parentPos, Value inPadZone = nullptr) const = 0;
 
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
index 92fbbf54558237..50dd989416e2a0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
@@ -30,16 +30,8 @@
 // Do the same run, but now with direct IR generation and VLA vectorization.
 // RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
 
-#CCCC = #sparse_tensor.encoding<{
-  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
-}>
-
-#CDCD = #sparse_tensor.encoding<{
-  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : dense)
-}>
-
-#DCCD = #sparse_tensor.encoding<{
-  map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : compressed, d3 : dense)
+#CDCC_NHWC = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : compressed)
 }>
 
 // Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
@@ -66,7 +58,7 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<3x8x8x3xf32>, %arg1: tensor<5x5x3x1xf
   return %ret : tensor<3x8x8x1xf32>
 }
 
-func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
+func.func @conv_2d_nhwc_hwcf_CDCC_NHWC(%arg0: tensor<3x8x8x3xf32, #CDCC_NHWC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
   %cst_0 = arith.constant 0.00000e+00 : f32
   %buf = tensor.empty() : tensor<3x8x8x1xf32>
   %s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
@@ -74,11 +66,11 @@ func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tens
   %padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
    ^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
      tensor.yield %cst_0 : f32
-   } : tensor<3x8x8x3xf32, #CCCC> to tensor<3x12x12x3xf32, #CCCC>
+   } : tensor<3x8x8x3xf32, #CDCC_NHWC> to tensor<3x12x12x3xf32, #CDCC_NHWC>
 
   %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                      strides = dense<1> : tensor<2xi64>}
-     ins (%padded, %arg1: tensor<3x12x12x3xf32, #CCCC>, tensor<5x5x3x1xf32>)
+     ins (%padded, %arg1: tensor<3x12x12x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>)
     outs (%s: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
   return %ret : tensor<3x8x8x1xf32>
 }
@@ -105,8 +97,8 @@ func.func @main() {
 
   %dense_ret = call @conv_2d_nhwc_hwcf(%static_input, %static_filter, %static_output) : (tensor<3x8x8x3xf32>, tensor<5x5x3x1xf32>, tensor<3x8x8x1xf32>) -> (tensor<3x8x8x1xf32>)
 
-  %in2D_nhwc_CCCC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CCCC>
-  %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %static_filter) : (tensor<3x8x8x3xf32, #CCCC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
+  %in2D_nhwc_CDCC_NHWC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CDCC_NHWC>
+  %CDCC_NHWC_ret = call @conv_2d_nhwc_hwcf_CDCC_NHWC(%in2D_nhwc_CDCC_NHWC, %static_filter) : (tensor<3x8x8x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
 
 
   // CHECK:   ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
@@ -161,17 +153,17 @@ func.func @main() {
   // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
   // CHECK-SAME:     ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
   // CHECK-SAME:     ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
-  %CCCC_v = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
+  %CDCC_NHWC_v = vector.transfer_read %CDCC_NHWC_ret[%c0, %c0, %c0, %c0], %zero
       : tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
-  vector.print %CCCC_v : vector<3x8x8x1xf32>
+  vector.print %CDCC_NHWC_v : vector<3x8x8x1xf32>
 
   bufferization.dealloc_tensor %static_filter : tensor<5x5x3x1xf32>
   bufferization.dealloc_tensor %static_input  : tensor<3x8x8x3xf32>
   bufferization.dealloc_tensor %static_output : tensor<3x8x8x1xf32>
 
-  bufferization.dealloc_tensor %CCCC_ret : tensor<3x8x8x1xf32>
+  bufferization.dealloc_tensor %CDCC_NHWC_ret : tensor<3x8x8x1xf32>
 
-  bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<3x8x8x3xf32, #CCCC>
+  bufferization.dealloc_tensor %in2D_nhwc_CDCC_NHWC : tensor<3x8x8x3xf32, #CDCC_NHWC>
 
   return
 }
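In short, the mechanism the diff above introduces: PadIterator carries one extra i1 cursor value, inPadZone, which locateImpl computes as crd < padLow || crd >= size + padLow (so for padLow = 2 and an unpadded size of 8, padded coordinates 0-1 and 10-11 fall in the pad zone). TrivialIterator::genInitImpl peels that flag off the parent's position and forwards it to peekRangeAt, where CompressedLevel guards its position-buffer loads with an scf.if that yields the empty range [0, 0) inside the pad zone. A sketch of the guarded IR follows; the value and buffer names are illustrative, not taken from the patch:

%lo, %hi = scf.if %in_pad -> (index, index) {
  // Parent coordinate lies in the pad zone: no stored entries to visit.
  scf.yield %c0, %c0 : index, index
} else {
  // Regular compressed-level lookup: positions[p] and positions[p + 1].
  %p_lo = memref.load %positions[%p] : memref<?xindex>
  %p1 = arith.addi %p, %c1 : index
  %p_hi = memref.load %positions[%p1] : memref<?xindex>
  scf.yield %p_lo, %p_hi : index, index
}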

@llvmbot (Member) commented Apr 30, 2024

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/90687.diff

@PeimingLiu force-pushed the handle-pad branch 3 times, most recently from 69f555f to c83ada7, on May 1, 2024 at 20:33.
@PeimingLiu merged commit 7888539 into llvm:main on May 1, 2024.
@PeimingLiu deleted the handle-pad branch on May 1, 2024 at 22:37.
Labels: mlir:sparse (Sparse compiler in MLIR), mlir