Skip to content

[mlir][sparse] support non-id map for [Dis]assembleOp #80355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 1, 2024

Conversation

PeimingLiu
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Feb 1, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/80355.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+36-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir (+48)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6033ebf6897ce..27125bc7ed45e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1016,8 +1016,6 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
     return op->emitError("the sparse-tensor must have static shape");
   if (!stt.hasEncoding())
     return op->emitError("the sparse-tensor must have an encoding attribute");
-  if (!stt.isIdentity())
-    return op->emitError("the sparse-tensor must have the identity mapping");
 
   // Verifies the trailing COO.
   Level cooStartLvl = stt.getCOOStart();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a0f7b55ce4446..fbe2fc31ab8b1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -656,6 +656,40 @@ struct TensorInsertDemapper
   }
 };
 
+struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AssembleOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!hasAnyNonIdentityOperandsOrResults(op))
+      return failure();
+
+    assert(hasAnySparseResult(op));
+    auto stt = getSparseTensorType(op.getResult());
+    rewriter.modifyOpInPlace(
+        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
+    rewriter.setInsertionPointAfter(op);
+    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
+    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
+    return success();
+  }
+};
+
+struct SparseDisassembleDemapper
+    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
+  using DemapInsRewriter::DemapInsRewriter;
+  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    if (!hasAnyNonIdentityOperandsOrResults(op))
+      return failure();
+
+    assert(hasAnySparseOperandOrResult(op));
+    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
+      op.getTensorMutable().assign(adaptor.getTensor());
+    });
+    return success();
+  }
+};
+
 struct ForeachOpDemapper
     : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
   using DemapInsRewriter::DemapInsRewriter;
@@ -758,7 +792,8 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
     patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
-                 TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
+                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
+                 SparseDisassembleDemapper, TensorInsertDemapper,
                  ForeachOpDemapper>(patterns.getContext());
   }
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 46f04cca03ed7..54de1024323b5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -80,3 +80,51 @@ func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<
   %9 = sparse_tensor.load %8 hasInserts : tensor<2x4xf64, #BSR>
   return %9 : tensor<2x4xf64, #BSR>
 }
+
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+   map = ( i, j ) ->
+      ( i floordiv 2 : dense,
+        j floordiv 2 : compressed,
+        i mod 2      : dense,
+        j mod 2      : dense
+      )
+}>
+// CHECK-DAG: #[[$remap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense) }>
+// CHECK-DAG: #[[$demap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : dense, d3 : dense) }>
+
+// CHECK-LABEL:   func.func @sparse_assemble_reinterpret_map(
+// CHECK-SAME:        %[[VAL_0:.*]]: tensor<?xf64>,
+// CHECK-SAME:        %[[VAL_1:.*]]: tensor<?xindex>,
+// CHECK-SAME:        %[[VAL_2:.*]]: tensor<?xindex>) -> tensor<2x4xf64, #[[$remap]]> {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.assemble %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_3]] : tensor<1x2x2x2xf64, #[[$demap]]> to tensor<2x4xf64, #[[$remap]]>
+// CHECK:           return %[[VAL_4]] : tensor<2x4xf64, #[[$remap]]>
+// CHECK:         }
+func.func @sparse_assemble_reinterpret_map(%val : tensor<?xf64>, %pos:tensor<?xindex>, %crd:tensor<?xindex>) -> tensor<2x4xf64, #BSR> {
+  %0 = sparse_tensor.assemble %val, %pos, %crd
+     : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<2x4xf64, #BSR>
+  return %0 : tensor<2x4xf64, #BSR>
+}
+
+// CHECK-LABEL:   func.func @sparse_disassemble_reinterpret_map(
+// CHECK-SAME:         %[[VAL_0:.*]]: tensor<2x4xf64, #[[$remap]]>,
+// CHECK-SAME:         %[[VAL_1:.*]]: tensor<?xf64>,
+// CHECK-SAME:         %[[VAL_2:.*]]: tensor<?xindex>,
+// CHECK-SAME:         %[[VAL_3:.*]]: tensor<?xindex>) -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64, #[[$remap]]> to tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK:           %[[VAL_5:.*]], %[[VAL_6:.*]]:2, %[[VAL_7:.*]], %[[VAL_8:.*]]:2 = sparse_tensor.disassemble %[[VAL_4]] : tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK:           return
+// CHECK:         }
+func.func @sparse_disassemble_reinterpret_map(%sp : tensor<2x4xf64, #BSR>,
+                                              %od : tensor<?xf64>,
+                                              %op : tensor<?xindex>,
+                                              %oi : tensor<?xindex>)
+                                            -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
+  %rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<2x4xf64, #BSR>
+                                 outs(%od, %op, %oi : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>)
+                                 -> tensor<?xf64>, (tensor<?xindex>, tensor<?xindex>), index, (index, index)
+  return %rd, %rp, %ri : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>
+}

Copy link
Contributor

@aartbik aartbik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we doing the right thing for permutations too (like CSC)?

@PeimingLiu PeimingLiu merged commit 07bf1dd into llvm:main Feb 1, 2024
@PeimingLiu PeimingLiu deleted the reinterpret-ass branch February 1, 2024 23:11
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants