[mlir][sparse] add a sparse_tensor.print operation #83321


Merged: 2 commits into llvm:main on Feb 28, 2024
Conversation

aartbik (Contributor) commented Feb 28, 2024

This operation is mainly used for testing and debugging purposes but provides a very convenient way to quickly inspect the contents of a sparse tensor (all components over all stored levels).

Example:

[ [ 1, 0, 2, 0, 0, 0, 0, 0 ],
  [ 0, 0, 0, 0, 0, 0, 0, 0 ],
  [ 0, 0, 0, 0, 0, 0, 0, 0 ],
  [ 0, 0, 3, 4, 0, 5, 0, 0 ] ]

When stored sparsely in DCSC format, this prints as:

---- Sparse Tensor ----
nse = 5
pos[0] : ( 0, 4, )
crd[0] : ( 0, 2, 3, 5, )
pos[1] : ( 0, 1, 3, 4, 5, )
crd[1] : ( 0, 0, 3, 3, 3, )
values : ( 1, 2, 3, 4, 5, )
----
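To reproduce this output, one can mirror the integration test added by this PR: build a dense constant, convert it to DCSC storage, and print it. A minimal sketch along those lines (the `#DCSC` encoding and the driver match the test file in the diff below):

```mlir
#DCSC = #sparse_tensor.encoding<{
  map = (i, j) -> (
    j : compressed,
    i : compressed
  )
}>

func.func @main() {
  %x = arith.constant dense<[
       [ 1, 0, 2, 0, 0, 0, 0, 0 ],
       [ 0, 0, 0, 0, 0, 0, 0, 0 ],
       [ 0, 0, 0, 0, 0, 0, 0, 0 ],
       [ 0, 0, 3, 4, 0, 5, 0, 0 ] ]> : tensor<4x8xi32>
  // Convert to DCSC storage and print all components.
  %d = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSC>
  sparse_tensor.print %d : tensor<4x8xi32, #DCSC>
  // Release the resources.
  bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
  return
}
```

Reading the DCSC output: `pos[0] : ( 0, 4, )` says four columns are stored and `crd[0]` lists their column indices (0, 2, 3, 5); for each stored column, `pos[1]` delimits a segment of `crd[1]` and `values`, so `crd[1]` holds the row index of every stored value in column-major order.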

llvmbot (Member) commented Feb 28, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

Full diff: https://github.com/llvm/llvm-project/pull/83321.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+22)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+94-1)
  • (added) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir (+215)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3127cf1b1bcf69..9007e4e98e3163 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1453,4 +1453,26 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Debugging Operations.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_PrintOp : SparseTensor_Op<"print">,
+    Arguments<(ins AnySparseTensor:$tensor)> {
+  string summary = "Prints a sparse tensor (for testing and debugging)";
+  string description = [{
+    Prints the individual components of a sparse tensor (the positions,
+    coordinates, and values components) to stdout for testing and debugging
+    purposes. This operation lowers to just a few primitives of a light-weight
+    runtime support library, simplifying support on new platforms.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.print %tensor : tensor<1024x1024xf64, #CSR>
+    ```
+  }];
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+}
+
 #endif // SPARSETENSOR_OPS
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1bcc131781d34d..fa078e2c2efd75 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -21,9 +21,11 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
@@ -598,6 +600,96 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
   }
 };
 
+/// Sparse rewriting rule for the print operator. This operation is mainly used
+/// for debugging and testing. As such, it lowers to the vector.print operation
+/// which only requires very light-weight runtime support.
+struct PrintRewriter : public OpRewritePattern<PrintOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrintOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto tensor = op.getTensor();
+    auto stt = getSparseTensorType(tensor);
+    // Header with NSE.
+    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
+    rewriter.create<vector::PrintOp>(
+        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
+    rewriter.create<vector::PrintOp>(loc, nse);
+    // Use the "codegen" foreach loop construct to iterate over
+    // all typical sparse tensor components for printing.
+    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
+                                            &tensor](Type tp, FieldIndex,
+                                                     SparseTensorFieldKind kind,
+                                                     Level l, LevelType) {
+      switch (kind) {
+      case SparseTensorFieldKind::StorageSpec: {
+        break;
+      }
+      case SparseTensorFieldKind::PosMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto poss = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, poss);
+        break;
+      }
+      case SparseTensorFieldKind::CrdMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto crds = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, crds);
+        break;
+      }
+      case SparseTensorFieldKind::ValMemRef: {
+        rewriter.create<vector::PrintOp>(loc,
+                                         rewriter.getStringAttr("values : "));
+        auto vals = rewriter.create<ToValuesOp>(loc, tp, tensor);
+        printContents(rewriter, loc, tp, vals);
+        break;
+      }
+      }
+      return true;
+    });
+    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  // Helper to print contents of a single memref. Note that for the "push_back"
+  // vectors, this prints the full capacity, not just the size. This is done
+  // on purpose, so that clients see how much storage has been allocated in
+  // total. Contents of the extra capacity in the buffer may be uninitialized
+  // (unless the flag enable-buffer-initialization is set to true).
+  //
+  // Generates code to print:
+  //    ( a0, a1, ... )
+  static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
+                            Value vec) {
+    // Open bracket.
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+    // For loop over elements.
+    auto zero = constantIndex(rewriter, loc, 0);
+    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+    auto step = constantIndex(rewriter, loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+    rewriter.setInsertionPointToStart(forOp.getBody());
+    auto idx = forOp.getInductionVar();
+    auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
+    rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
+    rewriter.setInsertionPointAfter(forOp);
+    // Close bracket and end of line.
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
 public:
@@ -1284,7 +1376,8 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
+               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+      patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
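For readers unfamiliar with the generated code, here is a hand-written sketch of roughly what `printContents` emits for the values component of a `#CSR` tensor `%t`; the value names and the exact `vector.print` assembly are illustrative, not taken verbatim from the pass output:

```mlir
// Print "values : ( v0, v1, ... )" for the values buffer of %t.
vector.print str "values : "
%vals = sparse_tensor.values %t : tensor<4x8xi32, #CSR> to memref<?xi32>
vector.print punctuation <open>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// The loop bound is the size of the underlying buffer.
%n = memref.dim %vals, %c0 : memref<?xi32>
scf.for %i = %c0 to %n step %c1 {
  %v = memref.load %vals[%i] : memref<?xi32>
  vector.print %v : i32 punctuation <comma>
}
vector.print punctuation <close>
vector.print punctuation <newline>
```

The `pos[l]` and `crd[l]` components are printed the same way, using `sparse_tensor.positions` and `sparse_tensor.coordinates` to obtain the buffers; and since the loop bound comes from `memref.dim`, the printout covers the buffer's full capacity, matching the `printContents` comment above.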
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
new file mode 100755
index 00000000000000..797a04bb9ff862
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
@@ -0,0 +1,215 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files, which is why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+
+#AllDense = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : dense
+  )
+}>
+
+#AllDenseT = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j : dense,
+    i : dense
+  )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+
+#DCSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed,
+    j : compressed
+  )
+}>
+
+#CSC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j : dense,
+    i : compressed
+  )
+}>
+
+#DCSC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j : compressed,
+    i : compressed
+  )
+}>
+
+#BSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i floordiv 2 : compressed,
+    j floordiv 4 : compressed,
+    i mod 2 : dense,
+    j mod 4 : dense
+  )
+}>
+
+#BSRC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i floordiv 2 : compressed,
+    j floordiv 4 : compressed,
+    j mod 4 : dense,
+    i mod 2 : dense
+  )
+}>
+
+#BSC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j floordiv 4 : compressed,
+    i floordiv 2 : compressed,
+    i mod 2 : dense,
+    j mod 4 : dense
+  )
+}>
+
+#BSCC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j floordiv 4 : compressed,
+    i floordiv 2 : compressed,
+    j mod 4 : dense,
+    i mod 2 : dense
+  )
+}>
+
+module {
+
+  //
+  // Main driver that tests sparse tensor storage.
+  //
+  func.func @main() {
+    %x = arith.constant dense <[
+         [ 1, 0, 2, 0, 0, 0, 0, 0 ],
+         [ 0, 0, 0, 0, 0, 0, 0, 0 ],
+         [ 0, 0, 0, 0, 0, 0, 0, 0 ],
+         [ 0, 0, 3, 4, 0, 5, 0, 0 ] ]> : tensor<4x8xi32>
+
+    %a = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSR>
+    %b = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSR>
+    %c = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSC>
+    %d = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSC>
+    %e = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR>
+    %f = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSRC>
+    %g = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC>
+    %h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>
+
+    //
+    // CHECK:      ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[1] : ( 0, 2, 2, 2, 5,
+    // CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %a : tensor<4x8xi32, #CSR>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 3,
+    // CHECK-NEXT: pos[1] : ( 0, 2, 5,
+    // CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %b : tensor<4x8xi32, #DCSR>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[1] : ( 0, 1, 1, 3, 4, 4, 5, 5, 5,
+    // CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %c : tensor<4x8xi32, #CSC>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[0] : ( 0, 4,
+    // CHECK-NEXT: crd[0] : ( 0, 2, 3, 5,
+    // CHECK-NEXT: pos[1] : ( 0, 1, 3, 4, 5,
+    // CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %d : tensor<4x8xi32, #DCSC>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 24
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 1,
+    // CHECK-NEXT: pos[1] : ( 0, 1, 3,
+    // CHECK-NEXT: crd[1] : ( 0, 0, 1,
+    // CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %e : tensor<4x8xi32, #BSR>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 24
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 1,
+    // CHECK-NEXT: pos[1] : ( 0, 1, 3,
+    // CHECK-NEXT: crd[1] : ( 0, 0, 1,
+    // CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %f : tensor<4x8xi32, #BSRC>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 24
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 1,
+    // CHECK-NEXT: pos[1] : ( 0, 2, 3,
+    // CHECK-NEXT: crd[1] : ( 0, 1, 1,
+    // CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %g : tensor<4x8xi32, #BSC>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 24
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 1,
+    // CHECK-NEXT: pos[1] : ( 0, 2, 3,
+    // CHECK-NEXT: crd[1] : ( 0, 1, 1,
+    // CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %h : tensor<4x8xi32, #BSCC>
+
+    // Release the resources.
+    bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
+    bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
+    bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
+    bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
+    bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
+    bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
+    bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
+    bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
+
+    return
+  }
+}
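One subtlety visible in the block formats above: `#BSR` and its variants store entire 2x4 blocks, so `nse` counts all 24 stored entries, explicit zeros included. A quick way to cross-check that count without the full printout, sketched under the assumption that `%e` is the `#BSR` tensor from the test:

```mlir
// nse counts every stored entry, including the zeros inside
// partially filled 2x4 blocks, hence 24 rather than 5.
%nse = sparse_tensor.number_of_entries %e : tensor<4x8xi32, #BSR>
vector.print %nse : index
```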

@aartbik aartbik merged commit d37affb into llvm:main Feb 28, 2024
@aartbik aartbik deleted the bik branch February 28, 2024 20:33
mylai-mtk pushed a commit to mylai-mtk/llvm-project that referenced this pull request Jul 12, 2024
Labels: mlir:sparse, mlir
4 participants