Skip to content

[mlir][sparse] infer returned type for sparse_tensor.to_[buffer] ops #83343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 29, 2024

Conversation

PeimingLiu
Copy link
Member

@PeimingLiu PeimingLiu commented Feb 28, 2024

The sparse structure buffers might not always be memrefs with rank == 1 with the presence of batch levels.

@llvmbot
Copy link
Member

llvmbot commented Feb 28, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

…r] ops.


Patch is 21.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83343.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+12-8)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+58)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+16-11)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir (+37-134)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9007e4e98e3163..3a5447d29f866d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -257,9 +257,10 @@ def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemory
   let hasVerifier = 1;
 }
 
-def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
+def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
+      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
-    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+    Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the `level`-th positions array of the `tensor`";
   let description = [{
     Returns the positions array of the tensor's storage at the given
@@ -283,9 +284,10 @@ def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
   let hasVerifier = 1;
 }
 
-def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates", [Pure]>,
+def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
+      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
-    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+    Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the `level`-th coordinates array of the `tensor`";
   let description = [{
     Returns the coordinates array of the tensor's storage at the given
@@ -309,9 +311,10 @@ def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates", [Pure]>,
   let hasVerifier = 1;
 }
 
-def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer", [Pure]>,
+def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
+      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
-    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+    Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the linear coordinates array from a tensor";
   let description = [{
     Returns the linear coordinates array for a sparse tensor with
@@ -340,9 +343,10 @@ def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer", [
   let hasVerifier = 1;
 }
 
-def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [Pure]>,
+def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
+      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
-    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+    Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts numerical values array from a tensor";
   let description = [{
     Returns the values array of the sparse storage format for the given
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 69c3413f35ea9c..a77bcf18d39489 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1445,6 +1445,38 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+template <typename ToBufferOp>
+static LogicalResult inferSparseBufferType(ValueRange ops,
+                                           SmallVectorImpl<mlir::Type> &ret) {
+  typename ToBufferOp::Adaptor adaptor(ops);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  Type elemTp = nullptr;
+  bool withStride;
+  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
+    elemTp = stt.getPosType();
+    withStride = false;
+  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
+                       std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
+    elemTp = stt.getCrdType();
+    withStride = std::is_same_v<ToBufferOp, ToCoordinatesOp> &&
+                 stt.getAoSCOOStart() < stt.getLvlRank();
+  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
+    elemTp = stt.getElementType();
+    withStride = false;
+  }
+
+  assert(elemTp && "unhandled operation.");
+  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
+  bufShape.push_back(ShapedType::kDynamic);
+
+  auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
+                                 stt.getContext(), ShapedType::kDynamic,
+                                 {ShapedType::kDynamic})
+                           : StridedLayoutAttr();
+  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
+  return success();
+}
+
 LogicalResult ToPositionsOp::verify() {
   auto stt = getSparseTensorType(getTensor());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1454,6 +1486,12 @@ LogicalResult ToPositionsOp::verify() {
   return success();
 }
 
+LogicalResult ToPositionsOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location>, ValueRange ops, DictionaryAttr,
+    OpaqueProperties, RegionRange, SmallVectorImpl<mlir::Type> &ret) {
+  return inferSparseBufferType<ToPositionsOp>(ops, ret);
+}
+
 LogicalResult ToCoordinatesOp::verify() {
   auto stt = getSparseTensorType(getTensor());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1463,6 +1501,12 @@ LogicalResult ToCoordinatesOp::verify() {
   return success();
 }
 
+LogicalResult ToCoordinatesOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location>, ValueRange ops, DictionaryAttr,
+    OpaqueProperties, RegionRange, SmallVectorImpl<mlir::Type> &ret) {
+  return inferSparseBufferType<ToCoordinatesOp>(ops, ret);
+}
+
 LogicalResult ToCoordinatesBufferOp::verify() {
   auto stt = getSparseTensorType(getTensor());
   if (stt.getAoSCOOStart() >= stt.getLvlRank())
@@ -1470,6 +1514,12 @@ LogicalResult ToCoordinatesBufferOp::verify() {
   return success();
 }
 
+LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location>, ValueRange ops, DictionaryAttr,
+    OpaqueProperties, RegionRange, SmallVectorImpl<mlir::Type> &ret) {
+  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, ret);
+}
+
 LogicalResult ToValuesOp::verify() {
   auto stt = getSparseTensorType(getTensor());
   auto mtp = getMemRefType(getResult());
@@ -1478,6 +1528,14 @@ LogicalResult ToValuesOp::verify() {
   return success();
 }
 
+LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *,
+                                           std::optional<Location>,
+                                           ValueRange ops, DictionaryAttr,
+                                           OpaqueProperties, RegionRange,
+                                           SmallVectorImpl<mlir::Type> &ret) {
+  return inferSparseBufferType<ToValuesOp>(ops, ret);
+}
+
 LogicalResult ToSliceOffsetOp::verify() {
   auto rank = getRankedTensorType(getSlice()).getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index c95b7b015b3725..6ff21468e05764 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -618,10 +618,10 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
     rewriter.create<vector::PrintOp>(loc, nse);
     // Use the "codegen" foreach loop construct to iterate over
     // all typical sparse tensor components for printing.
-    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
-                                            &tensor](Type tp, FieldIndex,
-                                                     SparseTensorFieldKind kind,
-                                                     Level l, LevelType) {
+    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
+                                            &stt](Type, FieldIndex,
+                                                  SparseTensorFieldKind kind,
+                                                  Level l, LevelType) {
       switch (kind) {
       case SparseTensorFieldKind::StorageSpec: {
         break;
@@ -632,8 +632,8 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
         rewriter.create<vector::PrintOp>(
             loc, lvl, vector::PrintPunctuation::NoPunctuation);
         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
-        auto pos = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
-        printContents(rewriter, loc, tp, pos);
+        auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
+        printContents(rewriter, loc, pos);
         break;
       }
       case SparseTensorFieldKind::CrdMemRef: {
@@ -642,15 +642,20 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
         rewriter.create<vector::PrintOp>(
             loc, lvl, vector::PrintPunctuation::NoPunctuation);
         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
-        auto crd = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
-        printContents(rewriter, loc, tp, crd);
+        Value crd = nullptr;
+        // TODO: eliminates ToCoordinateBufferOp!
+        if (stt.getAoSCOOStart() == l)
+          crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
+        else
+          crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
+        printContents(rewriter, loc, crd);
         break;
       }
       case SparseTensorFieldKind::ValMemRef: {
         rewriter.create<vector::PrintOp>(loc,
                                          rewriter.getStringAttr("values : "));
-        auto val = rewriter.create<ToValuesOp>(loc, tp, tensor);
-        printContents(rewriter, loc, tp, val);
+        auto val = rewriter.create<ToValuesOp>(loc, tensor);
+        printContents(rewriter, loc, val);
         break;
       }
       }
@@ -670,7 +675,7 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
   //
   // Generates code to print:
   //    ( a0, a1, ... )
-  static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
+  static void printContents(PatternRewriter &rewriter, Location loc,
                             Value vec) {
     // Open bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir
index c141df64c22e76..3a32ff28527001 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir
@@ -45,91 +45,6 @@
 
 
 module {
-
-  func.func @dump(%arg0: tensor<5x4x3xf64, #TensorCSR>) {
-    %c0 = arith.constant 0 : index
-    %fu = arith.constant 99.0 : f64
-    %p0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<5x4x3xf64, #TensorCSR> to memref<?xindex>
-    %i0 = sparse_tensor.coordinates  %arg0 { level = 0 : index } : tensor<5x4x3xf64, #TensorCSR> to memref<?xindex>
-    %p2 = sparse_tensor.positions %arg0 { level = 2 : index } : tensor<5x4x3xf64, #TensorCSR> to memref<?xindex>
-    %i2 = sparse_tensor.coordinates  %arg0 { level = 2 : index } : tensor<5x4x3xf64, #TensorCSR> to memref<?xindex>
-    %v = sparse_tensor.values %arg0 : tensor<5x4x3xf64, #TensorCSR> to memref<?xf64>
-    %vp0 = vector.transfer_read %p0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vp0 : vector<2xindex>
-    %vi0 = vector.transfer_read %i0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vi0 : vector<2xindex>
-    %vp2 = vector.transfer_read %p2[%c0], %c0: memref<?xindex>, vector<9xindex>
-    vector.print %vp2 : vector<9xindex>
-    %vi2 = vector.transfer_read %i2[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi2 : vector<5xindex>
-    %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<5xf64>
-    vector.print %vv : vector<5xf64>
-    return
-  }
-
-  func.func @dump_row(%arg0: tensor<5x4x3xf64, #TensorRow>) {
-    %c0 = arith.constant 0 : index
-    %fu = arith.constant 99.0 : f64
-    %p0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<5x4x3xf64, #TensorRow> to memref<?xindex>
-    %i0 = sparse_tensor.coordinates  %arg0 { level = 0 : index } : tensor<5x4x3xf64, #TensorRow> to memref<?xindex>
-    %p1 = sparse_tensor.positions %arg0 { level = 1 : index } : tensor<5x4x3xf64, #TensorRow> to memref<?xindex>
-    %i1 = sparse_tensor.coordinates  %arg0 { level = 1 : index } : tensor<5x4x3xf64, #TensorRow> to memref<?xindex>
-    %v = sparse_tensor.values %arg0 : tensor<5x4x3xf64, #TensorRow> to memref<?xf64>
-    %vp0 = vector.transfer_read %p0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vp0 : vector<2xindex>
-    %vi0 = vector.transfer_read %i0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vi0 : vector<2xindex>
-    %vp1 = vector.transfer_read %p1[%c0], %c0: memref<?xindex>, vector<3xindex>
-    vector.print %vp1 : vector<3xindex>
-    %vi1 = vector.transfer_read %i1[%c0], %c0: memref<?xindex>, vector<4xindex>
-    vector.print %vi1 : vector<4xindex>
-    %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<12xf64>
-    vector.print %vv : vector<12xf64>
-    return
-  }
-
-func.func @dump_ccoo(%arg0: tensor<5x4x3xf64, #CCoo>) {
-    %c0 = arith.constant 0 : index
-    %fu = arith.constant 99.0 : f64
-    %p0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %i0 = sparse_tensor.coordinates  %arg0 { level = 0 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %p1 = sparse_tensor.positions %arg0 { level = 1 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %i1 = sparse_tensor.coordinates  %arg0 { level = 1 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %i2 = sparse_tensor.coordinates  %arg0 { level = 2 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %v = sparse_tensor.values %arg0 : tensor<5x4x3xf64, #CCoo> to memref<?xf64>
-    %vp0 = vector.transfer_read %p0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vp0 : vector<2xindex>
-    %vi0 = vector.transfer_read %i0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vi0 : vector<2xindex>
-    %vp1 = vector.transfer_read %p1[%c0], %c0: memref<?xindex>, vector<3xindex>
-    vector.print %vp1 : vector<3xindex>
-    %vi1 = vector.transfer_read %i1[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi1 : vector<5xindex>
-    %vi2 = vector.transfer_read %i2[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi2 : vector<5xindex>
-    %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<5xf64>
-    vector.print %vv : vector<5xf64>
-    return
-  }
-
-func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
-    %c0 = arith.constant 0 : index
-    %fu = arith.constant 99.0 : f64
-    %p1 = sparse_tensor.positions %arg0 { level = 1 : index } : tensor<5x4x3xf64, #DCoo> to memref<?xindex>
-    %i1 = sparse_tensor.coordinates  %arg0 { level = 1 : index } : tensor<5x4x3xf64, #DCoo> to memref<?xindex>
-    %i2 = sparse_tensor.coordinates  %arg0 { level = 2 : index } : tensor<5x4x3xf64, #DCoo> to memref<?xindex>
-    %v = sparse_tensor.values %arg0 : tensor<5x4x3xf64, #DCoo> to memref<?xf64>
-    %vp1 = vector.transfer_read %p1[%c0], %c0: memref<?xindex>, vector<6xindex>
-    vector.print %vp1 : vector<6xindex>
-    %vi1 = vector.transfer_read %i1[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi1 : vector<5xindex>
-    %vi2 = vector.transfer_read %i2[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi2 : vector<5xindex>
-    %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<5xf64>
-    vector.print %vv : vector<5xf64>
-    return
-}
-
   //
   // Main driver.
   //
@@ -145,13 +60,14 @@ func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
     %f4 = arith.constant 4.4 : f64
     %f5 = arith.constant 5.5 : f64
 
-    //
-    // CHECK:      ( 0, 2 )
-    // CHECK-NEXT: ( 3, 4 )
-    // CHECK-NEXT: ( 0, 2, 2, 2, 3, 3, 3, 4, 5 )
-    // CHECK-NEXT: ( 1, 2, 1, 2, 2 )
-    // CHECK-NEXT: ( 1.1, 2.2, 3.3, 4.4, 5.5 )
-    //
+    // CHECK: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[0] : ( 0, 2
+    // CHECK-NEXT: crd[0] : ( 3, 4
+    // CHECK-NEXT: pos[2] : ( 0, 2, 2, 2, 3, 3, 3, 4, 5
+    // CHECK-NEXT: crd[2] : ( 1, 2, 1, 2, 2
+    // CHECK-NEXT: values : ( 1.1, 2.2, 3.3, 4.4, 5.5
+    // CHECK-NEXT: ----
     %tensora = tensor.empty() : tensor<5x4x3xf64, #TensorCSR>
     %tensor1 = sparse_tensor.insert %f1 into %tensora[%c3, %c0, %c1] : tensor<5x4x3xf64, #TensorCSR>
     %tensor2 = sparse_tensor.insert %f2 into %tensor1[%c3, %c0, %c2] : tensor<5x4x3xf64, #TensorCSR>
@@ -159,15 +75,16 @@ func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
     %tensor4 = sparse_tensor.insert %f4 into %tensor3[%c4, %c2, %c2] : tensor<5x4x3xf64, #TensorCSR>
     %tensor5 = sparse_tensor.insert %f5 into %tensor4[%c4, %c3, %c2] : tensor<5x4x3xf64, #TensorCSR>
     %tensorm = sparse_tensor.load %tensor5 hasInserts : tensor<5x4x3xf64, #TensorCSR>
-    call @dump(%tensorm) : (tensor<5x4x3xf64, #TensorCSR>) -> ()
-
-    //
-    // CHECK-NEXT: ( 0, 2 )
-    // CHECK-NEXT: ( 3, 4 )
-    // CHECK-NEXT: ( 0, 2, 4 )
-    // CHECK-NEXT: ( 0, 3, 2, 3 )
-    // CHECK-NEXT: ( 0, 1.1, 2.2, 0, 3.3, 0, 0, 0, 4.4, 0, 0, 5.5 )
-    //
+    sparse_tensor.print %tensorm : tensor<5x4x3xf64, #TensorCSR>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 12
+    // CHECK-NEXT: pos[0] : ( 0, 2
+    // CHECK-NEXT: crd[0] : ( 3, 4
+    // CHECK-NEXT: pos[1] : ( 0, 2, 4
+    // CHECK-NEXT: crd[1] : ( 0, 3, 2, 3
+    // CHECK-NEXT: values : ( 0, 1.1, 2.2, 0, 3.3, 0, 0, 0, 4.4, 0, 0, 5.5
+    // CHECK-NEXT: ----
     %rowa = tensor.empty() : tensor<5x4x3xf64, #TensorRow>
     %row1 = sparse_tensor.insert %f1 into %rowa[%c3, %c0, %c1] : tensor<5x4x3xf64, #TensorRow>
     %row2 = sparse_tensor.insert %f2 into %row1[%c3, %c0, %c2] : tensor<5x4x3xf64, #TensorRow>
@@ -175,15 +92,16 @@ func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
     %row4 = sparse_tensor.insert %f4 into %row3[%c4, %c2, %c2] : tensor<5x4x3xf64, #TensorRow>
     %row5 = sparse_tensor.insert %f5 into %row4[%c4, %c3, %c2] : tensor<5x4x3xf64, #TensorRow>
     %rowm = sparse_tensor.load %row5 hasInserts : tensor<5x4x3xf64, #TensorRow>
-    call @dump_row(%rowm) : (tensor<5x4x3xf64, #TensorRow>) -> ()
-
-    //
-    // CHECK: ( 0, 2 )
-    // CHECK-NEXT: ( 3, 4 )
-    // CHECK-NEXT: ( 0, 3, 5 )
-    // CHECK-NEXT: ( 0, 0, 3, 2, 3 )
-    // CHECK-NEXT: ( 1, 2, 1, 2, 2 )
-    // CHECK-NEXT: ( 1.1, 2.2, 3.3, 4.4, 5.5 )
+    sparse_tensor.print %rowm : tensor<5x4x3xf64, #TensorRow>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[0] : ( 0, 2
+    // CHECK-NEXT: crd[0] : ( 3, 4
+    // CHECK-NEXT: pos[1] : ( 0, 3, 5
+    // CHECK-NEXT: crd[1] : ( 0, 1, 0, 2, 3, 1, 2, 2, 3, 2
+    // CHECK-NEXT: values : ( 1.1, 2.2, 3.3, 4.4, 5.5
+    // CHECK-NEXT: ----
     %ccoo = tensor.empty() : tensor<5x4x3xf64, #CCoo>
     %ccoo1 = sparse_tensor.insert %f1 into %ccoo[%c3, %c0, %c1] : tensor<5x4x3xf64, #CCoo>
     %ccoo2 = sparse_tensor.insert %f2 into %ccoo1[%c3, %c0, %c2] : tensor<5x4x3xf64, #CCoo>
@@ -191,13 +109,14 @@ func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
     %ccoo4 = sparse_tensor.insert %f4 into %ccoo3[%c4, %c2, %c2] : tensor<5x4x3xf64, #CCoo>
     %ccoo5 = sparse_tensor.insert %f5 into %ccoo4[%c4, %c3, %c2] : tensor<5x4x3xf64, #CCoo>
     %ccoom = sparse_tensor.load %ccoo5 hasInserts : tensor<5x4x3xf64, #CCoo>
-    call @dump_ccoo(%ccoom) : (tensor<5x4x3xf64, #CCoo>) -> ()
-
-    //
-    // CHECK-NEXT: ( 0, 0, 0, 0, 3, 5 )
-    // CHECK-NEXT: ( 0, 0, 3, 2, 3 )
-    // CHECK-NEXT: ( 1, 2, 1, 2, 2 )
-    // CHECK-NEXT: ( 1.1, 2.2, 3.3, 4.4, 5.5 )
+    sparse_tensor.print %ccoom : tensor<5x4x3xf64, #CCoo>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[1] : ( 0, 0, 0, 0, 3, 5
+    // CHECK-NEXT: crd[1] : ( 0, 1, 0, 2, 3, 1, 2, 2, 3, 2
+    // CHECK-NEXT: values : ( 1.1, 2.2, 3.3, 4.4, 5.5
+    // CHECK-NEXT: ...
[truncated]

@PeimingLiu PeimingLiu changed the title [mlir][sparse] infer returned type for sparse_tensor.to_[sparse_buffe… [mlir][sparse] infer returned type for sparse_tensor.to_[buffer] ops Feb 28, 2024
@PeimingLiu PeimingLiu force-pushed the infer-memref branch 2 times, most recently from 6e848f6 to 2d796eb Compare February 28, 2024 22:03
@PeimingLiu PeimingLiu merged commit 6bc7c9d into llvm:main Feb 29, 2024
@PeimingLiu PeimingLiu deleted the infer-memref branch February 29, 2024 00:11
mylai-mtk pushed a commit to mylai-mtk/llvm-project that referenced this pull request Jul 12, 2024
…lvm#83343)

The sparse structure buffers might not always be memrefs with rank == 1
with the presence of batch levels.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants