[mlir][sparse] allow for direct-out passing of sparse tensor buffers #88327


Merged
merged 4 commits into llvm:main on Apr 11, 2024

Conversation

Contributor

@aartbik commented Apr 10, 2024

In order to support various external frameworks (JAX vs. PyTorch), we need a bit more flexibility in [dis]assembling external buffers to and from sparse tensors in MLIR land. This PR adds a direct-out option that avoids the rigid pre-allocated-buffer copy-out semantics.

Note that, over time, we expect the [dis]assemble operations to converge into something that supports all sorts of external frameworks. Until then, this option makes it easier to experiment with different designs.
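
To make the distinction concrete, the sketch below (not taken from the PR; names and types are illustrative) contrasts the public wrapper the pass would generate for a function returning one sparse tensor, with and without the new option:

```mlir
// Default mode: the caller pre-allocates the output buffers and passes them in
// as extra tensor arguments; the wrapper disassembles the sparse result into
// those buffers (copy-out semantics).
//   func.func @sparse_out(%arg0: tensor<64x64xf32>,
//                         %pos: tensor<?xindex>, %crd: tensor<?xindex>,
//                         %val: tensor<?xf32>)
//       -> (tensor<?xindex>, tensor<?xindex>, tensor<?xf32>)
//
// direct-out mode: no pre-allocated arguments; the wrapper returns the
// MLIR-allocated buffers directly as memrefs.
//   func.func @sparse_out(%arg0: tensor<64x64xf32>)
//       -> (memref<?xindex>, memref<?xindex>, memref<?xf32>)
```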

Member

@llvmbot commented Apr 10, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Changes

In order to support various external frameworks (JAX vs. PyTorch), we need a bit more flexibility in [dis]assembling external buffers to and from sparse tensors in MLIR land. This PR adds a direct-out option that avoids the rigid pre-allocated-buffer copy-out semantics.

Note that, over time, we expect the [dis]assemble operations to converge into something that supports all sorts of external frameworks. Until then, this option makes it easier to experiment with different designs.


Full diff: https://github.com/llvm/llvm-project/pull/88327.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+2-1)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+9)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+62-34)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (+6-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+2-1)
  • (added) mlir/test/Dialect/SparseTensor/external_direct.mlir (+35)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156b..d6d038ef65bdf4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -60,9 +60,10 @@ enum class SparseEmitStrategy {
 // The SparseAssembler pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseAssembler(RewritePatternSet &patterns);
+void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
 
 std::unique_ptr<Pass> createSparseAssembler();
+std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 
 //===----------------------------------------------------------------------===//
 // The SparseReinterpretMap pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c3..4706d5ba2f218c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -23,12 +23,21 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
     sparse tensors as numpy arrays from and to Python. Note that eventual
     bufferization decisions (e.g. who [de]allocates the underlying memory)
     should be resolved in agreement with the external runtime.
+
+    By default, the pass uses the [dis]assemble operations to input and output
+    sparse tensors. When the direct-out option is set, however, the output
+    directly returns the MLIR allocated buffers to the external runtime.
   }];
   let constructor = "mlir::createSparseAssembler()";
   let dependentDialects = [
+    "bufferization::BufferizationDialect",
     "sparse_tensor::SparseTensorDialect",
     "tensor::TensorDialect",
   ];
+  let options = [
+    Option<"directOut", "direct-out", "bool",
+      "false", "Directly returns buffers externally">,
+  ];
 }
 
 def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
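
As an aside, assuming the standard mlir-opt pass-option and textual pipeline syntax, the new option can be exercised roughly as follows (the added test below uses `direct-out=True`, which parses the same way):

```mlir
// Illustrative invocations only:
//   mlir-opt input.mlir --sparse-assembler="direct-out=true"
//   mlir-opt input.mlir --pass-pipeline="builtin.module(sparse-assembler{direct-out=true})"
```
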
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index a91d32a23cac9f..a2edc75fc38c02 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -8,6 +8,7 @@
 
 #include "Utils/CodegenUtils.h"
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
@@ -24,7 +25,7 @@ using namespace sparse_tensor;
 
 // Convert type range to new types range, with sparse tensors externalized.
 static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
-                      SmallVectorImpl<Type> *extraTypes = nullptr) {
+                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
   for (auto type : types) {
     // All "dense" data passes through unmodified.
     if (!getSparseTensorEncoding(type)) {
@@ -32,31 +33,38 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
       continue;
     }
 
-    // Convert the external representation of the position/coordinate array
+    // Convert the external representations of the pos/crd/val arrays.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
-                                               Type t, FieldIndex,
-                                               SparseTensorFieldKind kind,
-                                               Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
-          kind == SparseTensorFieldKind::ValMemRef) {
-        ShapedType st = t.cast<ShapedType>();
-        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
-        convTypes.push_back(rtp);
-        if (extraTypes)
-          extraTypes->push_back(rtp);
-      }
-      return true;
-    });
+    foreachFieldAndTypeInSparseTensor(
+        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
+                                                 SparseTensorFieldKind kind,
+                                                 Level, LevelType) {
+          if (kind == SparseTensorFieldKind::PosMemRef ||
+              kind == SparseTensorFieldKind::CrdMemRef ||
+              kind == SparseTensorFieldKind::ValMemRef) {
+            auto st = t.cast<ShapedType>();
+            auto shape = st.getShape();
+            auto eltTp = st.getElementType();
+            Type rtp;
+            if (directOut) {
+              rtp = MemRefType::get(shape, eltTp);
+            } else {
+              rtp = RankedTensorType::get(shape, eltTp);
+              if (extraTypes)
+                extraTypes->push_back(rtp);
+            }
+            convTypes.push_back(rtp);
+          }
+          return true;
+        });
   }
 }
 
 // Convert input and output values to [dis]assemble ops for sparse tensors.
 static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                      ValueRange fromVals, ValueRange extraVals,
-                     SmallVectorImpl<Value> &toVals, unsigned extra,
-                     bool isIn) {
+                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
+                     bool directOut) {
   unsigned idx = 0;
   for (auto type : types) {
     // All "dense" data passes through unmodified.
@@ -73,21 +81,34 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     if (!isIn)
       inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
-    // Collect the external representations of the pos/crd arrays.
+    // Collect the external representations of the pos/crd/val arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
-                                                     Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
+                                                     Level lv, LevelType) {
+      if (kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::CrdMemRef ||
           kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
         } else {
           ShapedType st = t.cast<ShapedType>();
           auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
-          inputs.push_back(extraVals[extra++]);
-          retTypes.push_back(rtp);
-          cntTypes.push_back(builder.getIndexType());
+          if (directOut) {
+            Value mem;
+            if (kind == SparseTensorFieldKind::PosMemRef)
+              mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
+                                                                 lv);
+            else if (kind == SparseTensorFieldKind::CrdMemRef)
+              mem = builder.create<sparse_tensor::ToCoordinatesOp>(
+                  loc, inputs[0], lv);
+            else
+              mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
+            toVals.push_back(mem);
+          } else {
+            inputs.push_back(extraVals[extra++]);
+            retTypes.push_back(rtp);
+            cntTypes.push_back(builder.getIndexType());
+          }
         }
       }
       return true;
@@ -97,7 +118,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
       // Assemble multiple inputs into a single sparse tensor.
       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
       toVals.push_back(a.getResult());
-    } else {
+    } else if (!directOut) {
       // Disassemble a single sparse input into multiple outputs.
       // Note that this includes the counters, which are dropped.
       unsigned len = retTypes.size();
@@ -144,11 +165,14 @@ namespace {
 //   return ..., t1..tn, ...
 // }
 //
-// TODO: refine output sparse tensors to work well with external framework
+// (with a direct-out variant without the disassemble).
 //
 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  SparseFuncAssembler(MLIRContext *context, bool dO)
+      : OpRewritePattern(context), directOut(dO) {}
+
   LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
     // Only rewrite public entry methods.
@@ -159,8 +183,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     SmallVector<Type> inputTypes;
     SmallVector<Type> outputTypes;
     SmallVector<Type> extraTypes;
-    convTypes(funcOp.getArgumentTypes(), inputTypes);
-    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
+    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, directOut);
+    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
 
     // Only sparse inputs or outputs need a wrapper method.
     if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
@@ -192,7 +216,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert inputs.
     SmallVector<Value> inputs;
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
-             ValueRange(), inputs, 0, /*isIn=*/true);
+             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
 
     // Call the original, now private method. A subsequent inlining pass can
     // determine whether cloning the method body in place is worthwhile.
@@ -203,7 +227,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert outputs and return.
     SmallVector<Value> outputs;
     convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
-             body->getArguments(), outputs, extra, /*isIn=*/false);
+             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
     rewriter.create<func::ReturnOp>(loc, outputs);
 
     // Finally, migrate a potential c-interface property.
@@ -215,6 +239,9 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     }
     return success();
   }
+
+private:
+  const bool directOut;
 };
 
 } // namespace
@@ -223,6 +250,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
 // Public method for populating conversion rules.
 //===----------------------------------------------------------------------===//
 
-void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
-  patterns.add<SparseFuncAssembler>(patterns.getContext());
+void mlir::populateSparseAssembler(RewritePatternSet &patterns,
+                                   bool directOut) {
+  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c52fa3751e6b4a..f0d162bdb84d96 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -767,6 +767,12 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
 };
 
 /// Sparse conversion rule for the sparse_tensor.disassemble operator.
+/// Note that the current implementation simply exposes the buffers to
+/// the external client. This assumes the client only reads the buffers
+/// (usually copying it to the external data structures, such as numpy
+/// arrays). The semantics of the disassemble operation technically
+/// require that the copying is done here already using the out-levels
+/// and out-values clause.
 class SparseTensorDisassembleConverter
     : public OpConversionPattern<DisassembleOp> {
 public:
@@ -774,9 +780,6 @@ class SparseTensorDisassembleConverter
   LogicalResult
   matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // We simply expose the buffers to the external client. This
-    // assumes the client only reads the buffers (usually copying it
-    // to the external data structures, such as numpy arrays).
     Location loc = op->getLoc();
     auto stt = getSparseTensorType(op.getTensor());
     SmallVector<Value> retVal;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index acea25f023980a..b42d58634a36c4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -50,11 +50,12 @@ namespace {
 struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
   SparseAssembler() = default;
   SparseAssembler(const SparseAssembler &pass) = default;
+  SparseAssembler(bool dO) { directOut = dO; }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseAssembler(patterns);
+    populateSparseAssembler(patterns, directOut);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/SparseTensor/external_direct.mlir b/mlir/test/Dialect/SparseTensor/external_direct.mlir
new file mode 100644
index 00000000000000..97a6d3031d90cd
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/external_direct.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --sparse-assembler="direct-out=True" -split-input-file | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_out(
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>)
+// CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])
+// CHECK:         %[[P:.*]] = sparse_tensor.positions %[[F]]
+// CHECK:         %[[C:.*]] = sparse_tensor.coordinates %[[F]]
+// CHECK:         %[[V:.*]] = sparse_tensor.values %[[F]]
+// CHECK:         return %[[P]], %[[C]], %[[V]]
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
+  %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
+  return %0 : tensor<64x64xf32, #sparse>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_out2(
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>)
+// CHECK:         %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
+// CHECK:         %[[P:.*]] = sparse_tensor.positions %[[F]]#1
+// CHECK:         %[[C:.*]] = sparse_tensor.coordinates %[[F]]#1
+// CHECK:         %[[V:.*]] = sparse_tensor.values %[[F]]#1
+// CHECK:         return %[[F]]#0, %[[P]], %[[C]], %[[V]]
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out2
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
+  %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
+  return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
+}
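
For readers skimming the CHECK lines, the following is a hand-expanded approximation of the direct-out wrapper expected for `@sparse_out`, reconstructed from the patterns above (exact printed types and attribute syntax may differ):

```mlir
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

func.func @sparse_out(%arg0: tensor<64x64xf32>)
    -> (memref<?xindex>, memref<?xindex>, memref<?xf32>) {
  // Call the original, now private, method to build the sparse tensor.
  %0 = call @_internal_sparse_out(%arg0)
      : (tensor<64x64xf32>) -> tensor<64x64xf32, #sparse>
  // Return the internally allocated pos/crd/val buffers directly,
  // instead of copying them into caller-provided output tensors.
  %pos = sparse_tensor.positions %0 {level = 1 : index}
      : tensor<64x64xf32, #sparse> to memref<?xindex>
  %crd = sparse_tensor.coordinates %0 {level = 1 : index}
      : tensor<64x64xf32, #sparse> to memref<?xindex>
  %val = sparse_tensor.values %0
      : tensor<64x64xf32, #sparse> to memref<?xf32>
  return %pos, %crd, %val : memref<?xindex>, memref<?xindex>, memref<?xf32>
}
```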

@aartbik requested a review from PeimingLiu on April 11, 2024 at 01:50
@aartbik merged commit 5122a2c into llvm:main on Apr 11, 2024
@aartbik deleted the bik branch on April 11, 2024 at 17:07
Labels
mlir:sparse (Sparse compiler in MLIR), mlir