
[mlir][sparse] fix error when convolution stride is applied on a dense level #79521


Merged: 1 commit merged into llvm:main from PeimingLiu:fix-dual-conv on Jan 26, 2024

Conversation

PeimingLiu (Member)

[mlir][sparse] fix error when convolution stride is applied on a dense level.


llvmbot commented Jan 25, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

[mlir][sparse] fix error when convolution stride is applied on a dense level.


Full diff: https://github.com/llvm/llvm-project/pull/79521.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+6-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h (+1-1)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir (+21)
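The gist of the change: when stride != 1, makeNonEmptySubSectIterator wraps the delegate iterator in a FilterIterator. Previously the filter's size was the sentinel C_IDX(-1); a sparse level is bounded by its stored coordinates anyway, but a dense level must derive its loop range from that size, so the sentinel left nothing to bound the loop. The fix threads the enclosing loop's upper bound (loopHighs[loop]) through as the filter size. As a hedged MLIR sketch (hypothetical function and SSA names, not the IR the pass actually emits), a strided walk over a dense level is essentially:

  func.func @strided_dense_walk(%offset: index, %ub: index, %stride: index) {
    // On a dense level every coordinate %offset, %offset + %stride, ...
    // exists, so only the upper bound %ub can terminate the walk.
    scf.for %c = %offset to %ub step %stride {
      // visit dense coordinate %c
    }
    return
  }

With the old sentinel of -1 there is no usable %ub from which to form this range, which is the failure the PR fixes; the new TODO in SparseTensorLevel.cpp notes that sparse levels could still safely skip the bound check.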
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 6d832fdc0c2201e..3fa4004ae460efc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -313,8 +313,8 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
           Value loopHi = loopHighs[loop];
           size = ADDI(size, MULI(loopHi, C_IDX(stride)));
         }
-        it = makeNonEmptySubSectIterator(builder, loc, parent, std::move(lvlIt),
-                                         size, curDep.second);
+        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
+                                         std::move(lvlIt), size, curDep.second);
       } else {
         Value size = loopHighs[loop];
         const SparseIterator &subSectIter = *iters[t][lvl].back();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 226cccbc422b9b6..e43896942d7fe68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1271,7 +1271,7 @@ static const IterType *unwrapFilter(const SparseIterator *it) {
 }
 
 std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
-    OpBuilder &b, Location l, const SparseIterator *parent,
+    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
     std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
 
   // Try unwrap the NonEmptySubSectIterator from a filter parent.
@@ -1279,9 +1279,12 @@ std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
   auto it = std::make_unique<NonEmptySubSectIterator>(
       b, l, parent, std::move(delegate), size);
 
-  if (stride != 1)
+  if (stride != 1) {
+    // TODO: We can safely skip bound checking on sparse levels, but for dense
+    // iteration space, we need the bound to infer the dense loop range.
     return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
-                                            C_IDX(stride), /*size=*/C_IDX(-1));
+                                            C_IDX(stride), /*size=*/loopBound);
+  }
   return it;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 08f7c6a747eb57f..d2b3396b72836c5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -246,7 +246,7 @@ makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
 /// Helper function to create a SparseIterator object that iterate over the
 /// non-empty subsections set.
 std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
-    OpBuilder &b, Location l, const SparseIterator *parent,
+    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
     std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
 
 /// Helper function to create a SparseIterator object that iterate over a
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
index 98adc26f06e56ba..8ee80045afc760b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
@@ -69,6 +69,14 @@ func.func @conv_2d_nhwc_hwcf_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tens
   return %ret : tensor<?x?x?x?xf32>
 }
 
+func.func @conv_2d_nhwc_hwcf_dual_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tensor<?x?x?x?xf32, #CDCC>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                     strides = dense<2> : tensor<2xi64>}
+     ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>)
+    outs (%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %ret : tensor<?x?x?x?xf32>
+}
+
 
 func.func @entry() {
   %c0 = arith.constant 0 : index
@@ -87,6 +95,8 @@ func.func @entry() {
 
   %in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
     : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CCCC>
+  %filter2D_nhwc_CDCC = sparse_tensor.convert %filter2D_nhwc
+    : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
   %in2D_nhwc_CDCC = sparse_tensor.convert %in2D_nhwc
     : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
 
@@ -94,9 +104,19 @@ func.func @entry() {
   %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
   %CDCC_ret = call @conv_2d_nhwc_hwcf_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
 
+  %dual_CDCC_ret = call @conv_2d_nhwc_hwcf_dual_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc_CDCC, %out2D_nhwc)
+    : (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
+
   // CHECK:      ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
   // CHECK-SAME:   ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
   // CHECK-SAME:   ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
+  %v_dual = vector.transfer_read %dual_CDCC_ret[%c0, %c0, %c0, %c0], %zero
+      : tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
+  vector.print %v_dual : vector<3x3x3x1xf32>
+
+  // CHECK-NEXT: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
+  // CHECK-SAME:   ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
+  // CHECK-SAME:   ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
   %dense_v = vector.transfer_read %dense_ret[%c0, %c0, %c0, %c0], %zero
       : tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
   vector.print %dense_v : vector<3x3x3x1xf32>
@@ -120,6 +140,7 @@ func.func @entry() {
   bufferization.dealloc_tensor %filter2D_nhwc : tensor<?x?x?x?xf32>
   bufferization.dealloc_tensor %out2D_nhwc : tensor<?x?x?x?xf32>
 
+  bufferization.dealloc_tensor %filter2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
   bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<?x?x?x?xf32, #CCCC>
   bufferization.dealloc_tensor %in2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
   return
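A note on what the new test covers: @conv_2d_nhwc_hwcf_dual_CDCC runs a stride-2 convolution with both the input and the filter in the #CDCC layout, so the strided spatial dimension lands on a dense level. The encoding is defined earlier in the test file and does not appear in this hunk; the following is only a sketch of what the name (compressed-dense-compressed-compressed) suggests, an assumption rather than the file's actual definition:

  // Hypothetical reconstruction of the #CDCC encoding used by the test.
  #CDCC = #sparse_tensor.encoding<{
    map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense,
                               d2 : compressed, d3 : compressed)
  }>

The dense d1 level sitting under a stride-2 loop is exactly the case the FilterIterator previously could not bound.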

PeimingLiu merged commit de5e4d7 into llvm:main on Jan 26, 2024.
PeimingLiu deleted the fix-dual-conv branch on January 26, 2024, at 01:11.
Labels: mlir:sparse (Sparse compiler in MLIR), mlir