Skip to content

[mlir][sparse] add a sparse_tensor.print operation #83321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1453,4 +1453,26 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging Operations.
//===----------------------------------------------------------------------===//

def SparseTensor_PrintOp : SparseTensor_Op<"print">,
Arguments<(ins AnySparseTensor:$tensor)> {
string summary = "Prints a sparse tensor (for testing and debugging)";
string description = [{
Prints the individual components of a sparse tensors (the positions,
coordinates, and values components) to stdout for testing and debugging
purposes. This operation lowers to just a few primitives in a light-weight
runtime support to simplify supporting this operation on new platforms.

Example:

```mlir
sparse_tensor.print %tensor : tensor<1024x1024xf64, #CSR>
```
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
}

#endif // SPARSETENSOR_OPS
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -598,6 +600,96 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
}
};

/// Sparse rewriting rule for the print operator. This operation is mainly used
/// for debugging and testing. As such, it lowers to the vector.print operation
/// which only require very light-weight runtime support.
struct PrintRewriter : public OpRewritePattern<PrintOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrintOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto tensor = op.getTensor();
auto stt = getSparseTensorType(tensor);
// Header with NSE.
auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
rewriter.create<vector::PrintOp>(
loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
rewriter.create<vector::PrintOp>(loc, nse);
// Use the "codegen" foreach loop construct to iterate over
// all typical sparse tensor components for printing.
foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
&tensor](Type tp, FieldIndex,
SparseTensorFieldKind kind,
Level l, LevelType) {
switch (kind) {
case SparseTensorFieldKind::StorageSpec: {
break;
}
case SparseTensorFieldKind::PosMemRef: {
auto lvl = constantIndex(rewriter, loc, l);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
rewriter.create<vector::PrintOp>(
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
auto pos = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
printContents(rewriter, loc, tp, pos);
break;
}
case SparseTensorFieldKind::CrdMemRef: {
auto lvl = constantIndex(rewriter, loc, l);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
rewriter.create<vector::PrintOp>(
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
auto crd = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
printContents(rewriter, loc, tp, crd);
break;
}
case SparseTensorFieldKind::ValMemRef: {
rewriter.create<vector::PrintOp>(loc,
rewriter.getStringAttr("values : "));
auto val = rewriter.create<ToValuesOp>(loc, tp, tensor);
printContents(rewriter, loc, tp, val);
break;
}
}
return true;
});
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
rewriter.eraseOp(op);
return success();
}

private:
// Helper to print contents of a single memref. Note that for the "push_back"
// vectors, this prints the full capacity, not just the size. This is done
// on purpose, so that clients see how much storage has been allocated in
// total. Contents of the extra capacity in the buffer may be uninitialized
// (unless the flag enable-buffer-initialization is set to true).
//
// Generates code to print:
// ( a0, a1, ... )
static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
Value vec) {
// Open bracket.
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
// For loop over elements.
auto zero = constantIndex(rewriter, loc, 0);
auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
auto step = constantIndex(rewriter, loc, 1);
auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
rewriter.setInsertionPointToStart(forOp.getBody());
auto idx = forOp.getInductionVar();
auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
rewriter.setInsertionPointAfter(forOp);
// Close bracket and end of line.
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
}
};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
Expand Down Expand Up @@ -1284,7 +1376,8 @@ struct OutRewriter : public OpRewritePattern<OutOp> {

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
patterns.getContext());
}

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
Expand Down
215 changes: 215 additions & 0 deletions mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
//--------------------------------------------------------------------------------------------------
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users that
// do not use these LIT config files. Hence why this is kept inline.
//
// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
// DEFINE: %{run_opts} = -e main -entry-point-result=void
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
//
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------

// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
// RUN: %{compile} | %{run} | FileCheck %s
//

#AllDense = #sparse_tensor.encoding<{
map = (i, j) -> (
i : dense,
j : dense
)
}>

#AllDenseT = #sparse_tensor.encoding<{
map = (i, j) -> (
j : dense,
i : dense
)
}>

#CSR = #sparse_tensor.encoding<{
map = (i, j) -> (
i : dense,
j : compressed
)
}>

#DCSR = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed,
j : compressed
)
}>

#CSC = #sparse_tensor.encoding<{
map = (i, j) -> (
j : dense,
i : compressed
)
}>

#DCSC = #sparse_tensor.encoding<{
map = (i, j) -> (
j : compressed,
i : compressed
)
}>

#BSR = #sparse_tensor.encoding<{
map = (i, j) -> (
i floordiv 2 : compressed,
j floordiv 4 : compressed,
i mod 2 : dense,
j mod 4 : dense
)
}>

#BSRC = #sparse_tensor.encoding<{
map = (i, j) -> (
i floordiv 2 : compressed,
j floordiv 4 : compressed,
j mod 4 : dense,
i mod 2 : dense
)
}>

#BSC = #sparse_tensor.encoding<{
map = (i, j) -> (
j floordiv 4 : compressed,
i floordiv 2 : compressed,
i mod 2 : dense,
j mod 4 : dense
)
}>

#BSCC = #sparse_tensor.encoding<{
map = (i, j) -> (
j floordiv 4 : compressed,
i floordiv 2 : compressed,
j mod 4 : dense,
i mod 2 : dense
)
}>

module {

//
// Main driver that tests sparse tensor storage.
//
func.func @main() {
%x = arith.constant dense <[
[ 1, 0, 2, 0, 0, 0, 0, 0 ],
[ 0, 0, 0, 0, 0, 0, 0, 0 ],
[ 0, 0, 0, 0, 0, 0, 0, 0 ],
[ 0, 0, 3, 4, 0, 5, 0, 0 ] ]> : tensor<4x8xi32>

%a = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSR>
%b = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSR>
%c = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSC>
%d = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSC>
%e = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR>
%f = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSRC>
%g = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC>
%h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>

//
// CHECK: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
// CHECK-NEXT: pos[1] : ( 0, 2, 2, 2, 5,
// CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
// CHECK-NEXT: ----
sparse_tensor.print %a : tensor<4x8xi32, #CSR>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
// CHECK-NEXT: pos[0] : ( 0, 2,
// CHECK-NEXT: crd[0] : ( 0, 3,
// CHECK-NEXT: pos[1] : ( 0, 2, 5,
// CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
// CHECK-NEXT: ----
sparse_tensor.print %b : tensor<4x8xi32, #DCSR>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
// CHECK-NEXT: pos[1] : ( 0, 1, 1, 3, 4, 4, 5, 5, 5,
// CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
// CHECK-NEXT: ----
sparse_tensor.print %c : tensor<4x8xi32, #CSC>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
// CHECK-NEXT: pos[0] : ( 0, 4,
// CHECK-NEXT: crd[0] : ( 0, 2, 3, 5,
// CHECK-NEXT: pos[1] : ( 0, 1, 3, 4, 5,
// CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
// CHECK-NEXT: ----
sparse_tensor.print %d : tensor<4x8xi32, #DCSC>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 24
// CHECK-NEXT: pos[0] : ( 0, 2,
// CHECK-NEXT: crd[0] : ( 0, 1,
// CHECK-NEXT: pos[1] : ( 0, 1, 3,
// CHECK-NEXT: crd[1] : ( 0, 0, 1,
// CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
// CHECK-NEXT: ----
sparse_tensor.print %e : tensor<4x8xi32, #BSR>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 24
// CHECK-NEXT: pos[0] : ( 0, 2,
// CHECK-NEXT: crd[0] : ( 0, 1,
// CHECK-NEXT: pos[1] : ( 0, 1, 3,
// CHECK-NEXT: crd[1] : ( 0, 0, 1,
// CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
// CHECK-NEXT: ----
sparse_tensor.print %f : tensor<4x8xi32, #BSRC>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 24
// CHECK-NEXT: pos[0] : ( 0, 2,
// CHECK-NEXT: crd[0] : ( 0, 1,
// CHECK-NEXT: pos[1] : ( 0, 2, 3,
// CHECK-NEXT: crd[1] : ( 0, 1, 1,
// CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
// CHECK-NEXT: ----
sparse_tensor.print %g : tensor<4x8xi32, #BSC>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 24
// CHECK-NEXT: pos[0] : ( 0, 2,
// CHECK-NEXT: crd[0] : ( 0, 1,
// CHECK-NEXT: pos[1] : ( 0, 2, 3,
// CHECK-NEXT: crd[1] : ( 0, 1, 1,
// CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
// CHECK-NEXT: ----
sparse_tensor.print %h : tensor<4x8xi32, #BSCC>

// Release the resources.
bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>

return
}
}