
Commit 5122a2c

[mlir][sparse] allow for direct-out passing of sparse tensor buffers (#88327)
In order to support various external frameworks (e.g. JAX vs. PyTorch), we need a bit more flexibility in [dis]assembling external buffers to and from sparse tensors in MLIR land. This PR adds a direct-out option that avoids the rigid semantics of copying out into pre-allocated buffers. Note that, over time, we expect the [dis]assemble operations to converge into something that supports all sorts of external frameworks. Until then, this option helps in experimenting with different approaches.
1 parent 64c3997 commit 5122a2c
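For illustration, a minimal C++ sketch of how an embedding framework might schedule the pass with the new option. Only createSparseAssembler(bool directOut) comes from this change; the helper function name and surrounding setup are assumptions.

// Sketch only: assumes the module was already imported by the external
// framework (e.g. a JAX or PyTorch bridge) and an MLIRContext is live.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

mlir::LogicalResult wrapForDirectOut(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext(),
                       mlir::ModuleOp::getOperationName());
  // With directOut=true, the generated wrappers hand the MLIR-allocated
  // pos/crd/val buffers straight back to the caller instead of copying
  // them into pre-allocated output tensors.
  pm.addPass(mlir::createSparseAssembler(/*directOut=*/true));
  return pm.run(module);
}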

6 files changed (+125, -38 lines)

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 2 additions & 1 deletion
@@ -60,9 +60,10 @@ enum class SparseEmitStrategy {
 // The SparseAssembler pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseAssembler(RewritePatternSet &patterns);
+void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
 
 std::unique_ptr<Pass> createSparseAssembler();
+std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 
 //===----------------------------------------------------------------------===//
 // The SparseReinterpretMap pass.

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
@@ -23,12 +23,21 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
     sparse tensors as numpy arrays from and to Python. Note that eventual
     bufferization decisions (e.g. who [de]allocates the underlying memory)
     should be resolved in agreement with the external runtime.
+
+    By default, the pass uses the [dis]assemble operations to input and output
+    sparse tensors. When the direct-out option is set, however, the output
+    directly returns the MLIR allocated buffers to the external runtime.
   }];
   let constructor = "mlir::createSparseAssembler()";
   let dependentDialects = [
+    "bufferization::BufferizationDialect",
     "sparse_tensor::SparseTensorDialect",
     "tensor::TensorDialect",
   ];
+  let options = [
+    Option<"directOut", "direct-out", "bool",
+           "false", "Directly returns buffers externally">,
+  ];
 }
 
 def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
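Because the direct-out option above is registered on the sparse-assembler pass, the same behavior can also be requested through MLIR's textual pipeline syntax. A rough sketch, assuming the sparse-tensor passes were registered beforehand (e.g. via registerAllPasses()) so "sparse-assembler" resolves by name; the function name is illustrative:

// Sketch only: textual equivalent of mlir-opt --sparse-assembler="direct-out=true".
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

mlir::LogicalResult runDirectOutPipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext(),
                       mlir::ModuleOp::getOperationName());
  // Parse the pass (with its option) into the module-anchored pass manager.
  if (mlir::failed(
          mlir::parsePassPipeline("sparse-assembler{direct-out=true}", pm)))
    return mlir::failure();
  return pm.run(module);
}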

mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp

Lines changed: 54 additions & 33 deletions
@@ -8,6 +8,7 @@
 
 #include "Utils/CodegenUtils.h"
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
@@ -24,39 +25,41 @@ using namespace sparse_tensor;
 
 // Convert type range to new types range, with sparse tensors externalized.
 static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
-                      SmallVectorImpl<Type> *extraTypes = nullptr) {
+                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
   for (auto type : types) {
     // All "dense" data passes through unmodified.
     if (!getSparseTensorEncoding(type)) {
       convTypes.push_back(type);
       continue;
     }
 
-    // Convert the external representation of the position/coordinate array
+    // Convert the external representations of the pos/crd/val arrays.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
-                                               Type t, FieldIndex,
-                                               SparseTensorFieldKind kind,
-                                               Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
-          kind == SparseTensorFieldKind::ValMemRef) {
-        ShapedType st = t.cast<ShapedType>();
-        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
-        convTypes.push_back(rtp);
-        if (extraTypes)
-          extraTypes->push_back(rtp);
-      }
-      return true;
-    });
+    foreachFieldAndTypeInSparseTensor(
+        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
+                                                 SparseTensorFieldKind kind,
+                                                 Level, LevelType) {
+          if (kind == SparseTensorFieldKind::PosMemRef ||
+              kind == SparseTensorFieldKind::CrdMemRef ||
+              kind == SparseTensorFieldKind::ValMemRef) {
+            auto rtp = t.cast<ShapedType>();
+            if (!directOut) {
+              rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+              if (extraTypes)
+                extraTypes->push_back(rtp);
+            }
+            convTypes.push_back(rtp);
+          }
+          return true;
+        });
   }
 }
 
 // Convert input and output values to [dis]assemble ops for sparse tensors.
 static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                      ValueRange fromVals, ValueRange extraVals,
-                     SmallVectorImpl<Value> &toVals, unsigned extra,
-                     bool isIn) {
+                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
+                     bool directOut) {
   unsigned idx = 0;
   for (auto type : types) {
     // All "dense" data passes through unmodified.
@@ -73,18 +76,29 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     if (!isIn)
       inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
-    // Collect the external representations of the pos/crd arrays.
+    // Collect the external representations of the pos/crd/val arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
-                                                     Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
+                                                     Level lv, LevelType) {
+      if (kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::CrdMemRef ||
           kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
+        } else if (directOut) {
+          Value mem;
+          if (kind == SparseTensorFieldKind::PosMemRef)
+            mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
+                                                               lv);
+          else if (kind == SparseTensorFieldKind::CrdMemRef)
+            mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
+                                                                 lv);
+          else
+            mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
+          toVals.push_back(mem);
         } else {
-          ShapedType st = t.cast<ShapedType>();
-          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
+          ShapedType rtp = t.cast<ShapedType>();
+          rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
           cntTypes.push_back(builder.getIndexType());
@@ -97,7 +111,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
       // Assemble multiple inputs into a single sparse tensor.
       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
       toVals.push_back(a.getResult());
-    } else {
+    } else if (!directOut) {
       // Disassemble a single sparse input into multiple outputs.
       // Note that this includes the counters, which are dropped.
       unsigned len = retTypes.size();
@@ -144,11 +158,14 @@ namespace {
 // return ..., t1..tn, ...
 // }
 //
-// TODO: refine output sparse tensors to work well with external framework
+// (with a direct-out variant without the disassemble).
 //
 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  SparseFuncAssembler(MLIRContext *context, bool dO)
+      : OpRewritePattern(context), directOut(dO) {}
+
   LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
     // Only rewrite public entry methods.
@@ -159,8 +176,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     SmallVector<Type> inputTypes;
     SmallVector<Type> outputTypes;
     SmallVector<Type> extraTypes;
-    convTypes(funcOp.getArgumentTypes(), inputTypes);
-    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
+    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
+    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
 
     // Only sparse inputs or outputs need a wrapper method.
     if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
@@ -192,7 +209,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert inputs.
     SmallVector<Value> inputs;
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
-             ValueRange(), inputs, 0, /*isIn=*/true);
+             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
 
     // Call the original, now private method. A subsequent inlining pass can
     // determine whether cloning the method body in place is worthwhile.
@@ -203,7 +220,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert outputs and return.
     SmallVector<Value> outputs;
     convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
-             body->getArguments(), outputs, extra, /*isIn=*/false);
+             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
     rewriter.create<func::ReturnOp>(loc, outputs);
 
     // Finally, migrate a potential c-interface property.
@@ -215,6 +232,9 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     }
     return success();
   }
+
+private:
+  const bool directOut;
 };
 
 } // namespace
@@ -223,6 +243,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
 // Public method for populating conversion rules.
 //===----------------------------------------------------------------------===//
 
-void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
-  patterns.add<SparseFuncAssembler>(patterns.getContext());
+void mlir::populateSparseAssembler(RewritePatternSet &patterns,
+                                   bool directOut) {
+  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
 }

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 6 additions & 3 deletions
@@ -767,16 +767,19 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
 };
 
 /// Sparse conversion rule for the sparse_tensor.disassemble operator.
+/// Note that the current implementation simply exposes the buffers to
+/// the external client. This assumes the client only reads the buffers
+/// (usually copying it to the external data structures, such as numpy
+/// arrays). The semantics of the disassemble operation technically
+/// require that the copying is done here already using the out-levels
+/// and out-values clause.
 class SparseTensorDisassembleConverter
     : public OpConversionPattern<DisassembleOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // We simply expose the buffers to the external client. This
-    // assumes the client only reads the buffers (usually copying it
-    // to the external data structures, such as numpy arrays).
     Location loc = op->getLoc();
     auto stt = getSparseTensorType(op.getTensor());
     SmallVector<Value> retVal;

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 2 additions & 1 deletion
@@ -50,11 +50,12 @@ namespace {
 struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
   SparseAssembler() = default;
   SparseAssembler(const SparseAssembler &pass) = default;
+  SparseAssembler(bool dO) { directOut = dO; }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseAssembler(patterns);
+    populateSparseAssembler(patterns, directOut);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+// RUN: mlir-opt %s --sparse-assembler="direct-out=True" -split-input-file | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_in(
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
+// CHECK:         %[[F:.*]] = call @_internal_sparse_in(%[[I]])
+// CHECK:         return %[[F]] : tensor<64x64xf32>
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_in
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
+  %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_out(
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>)
+// CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])
+// CHECK:         %[[P:.*]] = sparse_tensor.positions %[[F]]
+// CHECK:         %[[C:.*]] = sparse_tensor.coordinates %[[F]]
+// CHECK:         %[[V:.*]] = sparse_tensor.values %[[F]]
+// CHECK:         return %[[P]], %[[C]], %[[V]]
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
+  %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
+  return %0 : tensor<64x64xf32, #sparse>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_out2(
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>)
+// CHECK:         %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
+// CHECK:         %[[P:.*]] = sparse_tensor.positions %[[F]]#1
+// CHECK:         %[[C:.*]] = sparse_tensor.coordinates %[[F]]#1
+// CHECK:         %[[V:.*]] = sparse_tensor.values %[[F]]#1
+// CHECK:         return %[[F]]#0, %[[P]], %[[C]], %[[V]]
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out2
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
+  %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
+  return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
+}
