 
 #include "Utils/CodegenUtils.h"
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
@@ -24,39 +25,41 @@ using namespace sparse_tensor;
 
 // Convert type range to new types range, with sparse tensors externalized.
 static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
-                      SmallVectorImpl<Type> *extraTypes = nullptr) {
+                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
   for (auto type : types) {
     // All "dense" data passes through unmodified.
     if (!getSparseTensorEncoding(type)) {
       convTypes.push_back(type);
       continue;
     }
 
-    // Convert the external representation of the position/coordinate array
+    // Convert the external representations of the pos/crd/val arrays.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
-                                               Type t, FieldIndex,
-                                               SparseTensorFieldKind kind,
-                                               Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
-          kind == SparseTensorFieldKind::ValMemRef) {
-        ShapedType st = t.cast<ShapedType>();
-        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
-        convTypes.push_back(rtp);
-        if (extraTypes)
-          extraTypes->push_back(rtp);
-      }
-      return true;
-    });
+    foreachFieldAndTypeInSparseTensor(
+        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
+                                                 SparseTensorFieldKind kind,
+                                                 Level, LevelType) {
+          if (kind == SparseTensorFieldKind::PosMemRef ||
+              kind == SparseTensorFieldKind::CrdMemRef ||
+              kind == SparseTensorFieldKind::ValMemRef) {
+            auto rtp = t.cast<ShapedType>();
+            if (!directOut) {
+              rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+              if (extraTypes)
+                extraTypes->push_back(rtp);
+            }
+            convTypes.push_back(rtp);
+          }
+          return true;
+        });
   }
 }
 
 // Convert input and output values to [dis]assemble ops for sparse tensors.
 static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                      ValueRange fromVals, ValueRange extraVals,
-                     SmallVectorImpl<Value> &toVals, unsigned extra,
-                     bool isIn) {
+                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
+                     bool directOut) {
   unsigned idx = 0;
   for (auto type : types) {
     // All "dense" data passes through unmodified.
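Note (illustration only, not part of the patch): to make the effect of convTypes concrete, here is a minimal sketch of how the new signature would be driven for a function returning a single 2-D CSR tensor. The #CSR encoding, default index widths, and the funcOp variable are assumptions for the example.

    // Sketch: externalize the result types of a func::FuncOp that returns
    // one tensor<?x?xf64, #CSR> (assumed encoding, default index widths).
    SmallVector<Type> outputTypes, extraTypes;
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes,
              /*directOut=*/false);
    // outputTypes: tensor<?xindex> (positions), tensor<?xindex> (coordinates),
    //              tensor<?xf64> (values); dense results pass through as-is.
    // extraTypes:  the same three tensor types, so the wrapper can also take
    //              them as pre-allocated output arguments.
    // With /*directOut=*/true, the original memref field types are kept
    // instead and extraTypes stays empty.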
@@ -73,18 +76,29 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     if (!isIn)
       inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
-    // Collect the external representations of the pos/crd arrays.
+    // Collect the external representations of the pos/crd/val arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
-                                                     Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
+                                                     Level lv, LevelType) {
+      if (kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::CrdMemRef ||
           kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
+        } else if (directOut) {
+          Value mem;
+          if (kind == SparseTensorFieldKind::PosMemRef)
+            mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
+                                                               lv);
+          else if (kind == SparseTensorFieldKind::CrdMemRef)
+            mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
+                                                                 lv);
+          else
+            mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
+          toVals.push_back(mem);
         } else {
-          ShapedType st = t.cast<ShapedType>();
-          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
+          ShapedType rtp = t.cast<ShapedType>();
+          rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
           cntTypes.push_back(builder.getIndexType());
@@ -97,7 +111,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
       // Assemble multiple inputs into a single sparse tensor.
       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
       toVals.push_back(a.getResult());
-    } else {
+    } else if (!directOut) {
       // Disassemble a single sparse input into multiple outputs.
       // Note that this includes the counters, which are dropped.
       unsigned len = retTypes.size();
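Note (sketch, not from the patch): the directOut branch above amounts to returning live views of the sparse result's internal buffers rather than copying them out. For a single assumed CSR result, the emitted ops boil down to the following; tensor, loc, builder, and the level value are placeholders for illustration.

    // Sketch only: views of the result's internal storage are returned as-is.
    Value pos = builder.create<sparse_tensor::ToPositionsOp>(loc, tensor,
                                                             /*level=*/1);
    Value crd = builder.create<sparse_tensor::ToCoordinatesOp>(loc, tensor,
                                                               /*level=*/1);
    Value val = builder.create<sparse_tensor::ToValuesOp>(loc, tensor);
    // These memrefs go straight onto toVals; there is no AssembleOp/
    // DisassembleOp round-trip and no extra output buffers are needed.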
@@ -144,11 +158,14 @@ namespace {
 // return ..., t1..tn, ...
 // }
 //
-// TODO: refine output sparse tensors to work well with external framework
+// (with a direct-out variant without the disassemble).
 //
 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  SparseFuncAssembler(MLIRContext *context, bool dO)
+      : OpRewritePattern(context), directOut(dO) {}
+
   LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
     // Only rewrite public entry methods.
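Note (sketch in the same pseudocode style as the file comment above; not from the patch): with the direct-out variant the output wrapper skips the disassemble step and the extra buffer arguments, roughly:

    // void foo(...) -> (..., t1..tn, ...) {
    //   ..., t, ... = _internal_foo(...)
    //   t1..tn = views of t's positions/coordinates/values buffers
    //   return ..., t1..tn, ...
    // }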
@@ -159,8 +176,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     SmallVector<Type> inputTypes;
     SmallVector<Type> outputTypes;
     SmallVector<Type> extraTypes;
-    convTypes(funcOp.getArgumentTypes(), inputTypes);
-    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
+    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
+    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
 
     // Only sparse inputs or outputs need a wrapper method.
     if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
@@ -192,7 +209,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert inputs.
     SmallVector<Value> inputs;
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
-             ValueRange(), inputs, 0, /*isIn=*/true);
+             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
 
     // Call the original, now private method. A subsequent inlining pass can
     // determine whether cloning the method body in place is worthwhile.
@@ -203,7 +220,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert outputs and return.
     SmallVector<Value> outputs;
     convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
-             body->getArguments(), outputs, extra, /*isIn=*/false);
+             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
     rewriter.create<func::ReturnOp>(loc, outputs);
 
     // Finally, migrate a potential c-interface property.
@@ -215,6 +232,9 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     }
     return success();
   }
+
+private:
+  const bool directOut;
 };
 
 } // namespace
@@ -223,6 +243,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
 // Public method for populating conversion rules.
 //===----------------------------------------------------------------------===//
 
-void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
-  patterns.add<SparseFuncAssembler>(patterns.getContext());
+void mlir::populateSparseAssembler(RewritePatternSet &patterns,
+                                   bool directOut) {
+  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
 }
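Note (assumption, not part of this file): the directOut flag is presumably threaded in from the pass driver that owns the corresponding pass option. A minimal sketch of how a caller might wire it up, assuming a tablegen-generated pass base that exposes a bool directOut option:

    // Sketch only: pass class and option registration omitted.
    void runOnOperation() override {
      RewritePatternSet patterns(&getContext());
      populateSparseAssembler(patterns, directOut);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }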