[mlir][sparse] allow for direct-out passing of sparse tensor buffers #88327

Merged (4 commits, Apr 11, 2024)
@@ -60,9 +60,10 @@ enum class SparseEmitStrategy {
// The SparseAssembler pass.
//===----------------------------------------------------------------------===//

void populateSparseAssembler(RewritePatternSet &patterns);
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);

std::unique_ptr<Pass> createSparseAssembler();
std::unique_ptr<Pass> createSparseAssembler(bool directOut);

//===----------------------------------------------------------------------===//
// The SparseReinterpretMap pass.
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -23,12 +23,21 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
sparse tensors as numpy arrays from and to Python. Note that eventual
bufferization decisions (e.g. who [de]allocates the underlying memory)
should be resolved in agreement with the external runtime.

By default, the pass uses the [dis]assemble operations to input and output
sparse tensors. When the direct-out option is set, however, the output
directly returns the MLIR-allocated buffers to the external runtime.
}];
let constructor = "mlir::createSparseAssembler()";
let dependentDialects = [
"bufferization::BufferizationDialect",
"sparse_tensor::SparseTensorDialect",
"tensor::TensorDialect",
];
let options = [
Option<"directOut", "direct-out", "bool",
"false", "Directly returns buffers externally">,
];
}

def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
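For reference, the new option is driven from mlir-opt exactly as in the test file added at the end of this PR; a minimal sketch (RUN line and encoding taken from that test, function body illustrative):

// RUN: mlir-opt %s --sparse-assembler="direct-out=true"

#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// Public entry point with a sparse result: with direct-out enabled, the
// generated wrapper returns the positions/coordinates/values buffers of the
// result directly instead of routing them through sparse_tensor.disassemble.
func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
  %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
  return %0 : tensor<64x64xf32, #sparse>
}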
87 changes: 54 additions & 33 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -8,6 +8,7 @@

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
@@ -24,39 +25,41 @@ using namespace sparse_tensor;

// Convert type range to new types range, with sparse tensors externalized.
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
SmallVectorImpl<Type> *extraTypes = nullptr) {
SmallVectorImpl<Type> *extraTypes, bool directOut) {
for (auto type : types) {
// All "dense" data passes through unmodified.
if (!getSparseTensorEncoding(type)) {
convTypes.push_back(type);
continue;
}

// Convert the external representation of the position/coordinate array
// Convert the external representations of the pos/crd/val arrays.
const SparseTensorType stt(cast<RankedTensorType>(type));
foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
ShapedType st = t.cast<ShapedType>();
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
convTypes.push_back(rtp);
if (extraTypes)
extraTypes->push_back(rtp);
}
return true;
});
foreachFieldAndTypeInSparseTensor(
stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
auto rtp = t.cast<ShapedType>();
if (!directOut) {
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
if (extraTypes)
extraTypes->push_back(rtp);
}
convTypes.push_back(rtp);
}
return true;
});
}
}
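To make the effect on the converted types concrete, here is a sketch of the wrapper signatures for a function returning a 64x64 tensor with the dense/compressed (CSR-like) encoding used in the new test; the buffer types below are inferred from the code above and are an assumption, not something the test checks verbatim:

// Hypothetical wrapper signatures (sketch only).
//
// direct-out = false (default): each pos/crd/val field is externalized as a
// ranked tensor that is returned and, for results, also appended as an extra
// output-buffer argument consumed by sparse_tensor.disassemble.
func.func private @wrapper_default(tensor<64x64xf32>,
                                   tensor<?xindex>, tensor<?xindex>, tensor<?xf32>)
    -> (tensor<?xindex>, tensor<?xindex>, tensor<?xf32>)

// direct-out = true: the storage buffers are returned as-is, keeping their
// memref types as queried from the sparse tensor.
func.func private @wrapper_direct(tensor<64x64xf32>)
    -> (memref<?xindex>, memref<?xindex>, memref<?xf32>)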

// Convert input and output values to [dis]assemble ops for sparse tensors.
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
ValueRange fromVals, ValueRange extraVals,
SmallVectorImpl<Value> &toVals, unsigned extra,
bool isIn) {
SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
bool directOut) {
unsigned idx = 0;
for (auto type : types) {
// All "dense" data passes through unmodified.
@@ -73,18 +76,29 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
if (!isIn)
inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble

// Collect the external representations of the pos/crd arrays.
// Collect the external representations of the pos/crd/val arrays.
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::PosMemRef ||
Level lv, LevelType) {
if (kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
} else if (directOut) {
Value mem;
if (kind == SparseTensorFieldKind::PosMemRef)
mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
lv);
else if (kind == SparseTensorFieldKind::CrdMemRef)
mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
lv);
else
mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
toVals.push_back(mem);
} else {
ShapedType st = t.cast<ShapedType>();
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
ShapedType rtp = t.cast<ShapedType>();
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
cntTypes.push_back(builder.getIndexType());
@@ -97,7 +111,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
// Assemble multiple inputs into a single sparse tensor.
auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
toVals.push_back(a.getResult());
} else {
} else if (!directOut) {
// Disassemble a single sparse input into multiple outputs.
// Note that this includes the counters, which are dropped.
unsigned len = retTypes.size();
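Concretely, the direct-out branch above queries the buffers of the returned sparse tensor instead of disassembling it; for a CSR-like result the generated wrapper looks roughly as follows (a sketch: function and SSA names are illustrative, and the level attributes assume the compressed level is level 1):

#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

func.func private @_internal_sparse_out(tensor<64x64xf32>) -> tensor<64x64xf32, #sparse>

func.func @sparse_out(%x: tensor<64x64xf32>)
    -> (memref<?xindex>, memref<?xindex>, memref<?xf32>) {
  // Call the now-private original method, then expose its storage buffers.
  %f = call @_internal_sparse_out(%x)
      : (tensor<64x64xf32>) -> tensor<64x64xf32, #sparse>
  %p = sparse_tensor.positions %f { level = 1 : index }
      : tensor<64x64xf32, #sparse> to memref<?xindex>
  %c = sparse_tensor.coordinates %f { level = 1 : index }
      : tensor<64x64xf32, #sparse> to memref<?xindex>
  %v = sparse_tensor.values %f
      : tensor<64x64xf32, #sparse> to memref<?xf32>
  return %p, %c, %v : memref<?xindex>, memref<?xindex>, memref<?xf32>
}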
@@ -144,11 +158,14 @@ namespace {
// return ..., t1..tn, ...
// }
//
// TODO: refine output sparse tensors to work well with external framework
// (with a direct-out variant without the disassemble).
//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern::OpRewritePattern;

SparseFuncAssembler(MLIRContext *context, bool dO)
: OpRewritePattern(context), directOut(dO) {}

LogicalResult matchAndRewrite(func::FuncOp funcOp,
PatternRewriter &rewriter) const override {
// Only rewrite public entry methods.
@@ -159,8 +176,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
SmallVector<Type> extraTypes;
convTypes(funcOp.getArgumentTypes(), inputTypes);
convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);

// Only sparse inputs or outputs need a wrapper method.
if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
@@ -192,7 +209,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
// Convert inputs.
SmallVector<Value> inputs;
convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
ValueRange(), inputs, 0, /*isIn=*/true);
ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);

// Call the original, now private method. A subsequent inlining pass can
// determine whether cloning the method body in place is worthwhile.
@@ -203,7 +220,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
// Convert outputs and return.
SmallVector<Value> outputs;
convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
body->getArguments(), outputs, extra, /*isIn=*/false);
body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
rewriter.create<func::ReturnOp>(loc, outputs);

// Finally, migrate a potential c-interface property.
@@ -215,6 +232,9 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
}
return success();
}

private:
const bool directOut;
};

} // namespace
@@ -223,6 +243,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
patterns.add<SparseFuncAssembler>(patterns.getContext());
void mlir::populateSparseAssembler(RewritePatternSet &patterns,
bool directOut) {
patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
}
@@ -767,16 +767,19 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
/// Note that the current implementation simply exposes the buffers to
/// the external client. This assumes the client only reads the buffers
/// (usually copying them into external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copying is done here already using the out-levels
/// and out-values clause.
class SparseTensorDisassembleConverter
: public OpConversionPattern<DisassembleOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We simply expose the buffers to the external client. This
// assumes the client only reads the buffers (usually copying it
// to the external data structures, such as numpy arrays).
Location loc = op->getLoc();
auto stt = getSparseTensorType(op.getTensor());
SmallVector<Value> retVal;
@@ -50,11 +50,12 @@ namespace {
struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
SparseAssembler() = default;
SparseAssembler(const SparseAssembler &pass) = default;
SparseAssembler(bool dO) { directOut = dO; }

void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateSparseAssembler(patterns);
populateSparseAssembler(patterns, directOut);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
52 changes: 52 additions & 0 deletions mlir/test/Dialect/SparseTensor/external_direct.mlir
@@ -0,0 +1,52 @@
// RUN: mlir-opt %s --sparse-assembler="direct-out=True" -split-input-file | FileCheck %s

// -----

// CHECK-LABEL: func.func @sparse_in(
// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
// CHECK: func.func private @_internal_sparse_in
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
%0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
return %0 : tensor<64x64xf32>
}

// -----

// CHECK-LABEL: func.func @sparse_out(
// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>)
// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
// CHECK: %[[P:.*]] = sparse_tensor.positions %[[F]]
// CHECK: %[[C:.*]] = sparse_tensor.coordinates %[[F]]
// CHECK: %[[V:.*]] = sparse_tensor.values %[[F]]
// CHECK: return %[[P]], %[[C]], %[[V]]
// CHECK: }
// CHECK: func.func private @_internal_sparse_out
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
%0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
return %0 : tensor<64x64xf32, #sparse>
}

// -----

// CHECK-LABEL: func.func @sparse_out2(
// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>)
// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
// CHECK: %[[P:.*]] = sparse_tensor.positions %[[F]]#1
// CHECK: %[[C:.*]] = sparse_tensor.coordinates %[[F]]#1
// CHECK: %[[V:.*]] = sparse_tensor.values %[[F]]#1
// CHECK: return %[[F]]#0, %[[P]], %[[C]], %[[V]]
// CHECK: }
// CHECK: func.func private @_internal_sparse_out2
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
%0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
}