using namespace mlir;
using namespace mlir::sparse_tensor;

- namespace {
-
//===----------------------------------------------------------------------===//
- // Helper methods.
+ // File Local Helper methods.
//===----------------------------------------------------------------------===//

// Translates a "simple" map according to an identity lvl-map.
@@ -51,6 +49,27 @@ static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

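+ /// Remaps each value in `outs` to the corresponding type in `types`,
+ /// inserting a `reinterpret_map` wherever the two types differ.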
+ static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
+                                           ValueRange outs) {
+   SmallVector<Value> ret(outs);
+   assert(outs.size() == types.size());
+   for (auto [r, t] : llvm::zip(ret, types))
+     if (r.getType() != t)
+       r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
+   return ret;
+ }
+
+ /// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
+ static bool hasNonIdentityOperandsOrResults(Operation *op) {
+   auto hasNonIdentityMap = [](Value v) {
+     auto stt = tryGetSparseTensorType(v);
+     return stt && !stt->isIdentity();
+   };
+
+   return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
+          llvm::any_of(op->getResults(), hasNonIdentityMap);
+ }
+
// Generates a clone of the given linalg generic operation, but with
// remapped arguments, index maps, and iteration types.
//
@@ -86,6 +105,8 @@ static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
  return newOp;
}

+ namespace {
+
//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//
@@ -142,21 +163,17 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
};

//===----------------------------------------------------------------------===//
- // Rewriting rules for operations other than linalg generic ops.
+ // Reinterpret Map Rewriters for operations other than linalg.generics.
//===----------------------------------------------------------------------===//

- // CRTP to help implementing a rewriter that demaps all its inputs and remaps
- // all its outputs.
+ // CRTP to help implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
- struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
+ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
-     if (!static_cast<const SubClass *>(this)->matchOp(op))
-       return failure();
-
    Location loc = op.getLoc();
    // Demaps non-trivial inputs.
    SmallVector<Value> deMappedIns(op->getOperands());
@@ -166,61 +183,119 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {

    // CRTP call.
    OpAdaptor adaptor(deMappedIns);
-     ValueRange outs =
-         static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
-     assert(outs.size() == op->getResults().size());
-
-     // Remap outputs.
-     SmallVector<Value> reMappedOuts(outs);
-     for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
-       if (r.getType() != a.getType())
-         r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
-
-     rewriter.replaceOp(op, reMappedOuts);
-     return success();
+     return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+                                                           rewriter);
  }
};

- struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
-   using OpRewritePattern::OpRewritePattern;
-   LogicalResult matchAndRewrite(CrdTranslateOp op,
-                                 PatternRewriter &rewriter) const override {
-     AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
-                         ? op.getEncoder().getDimToLvl()
-                         : op.getEncoder().getLvlToDim();
-
-     SmallVector<Value> outCrds;
-     for (AffineExpr result : map.getResults()) {
-       // TODO: we should probably expand the affine map to IR using our own
-       // rules, since affine.apply assumes signed values, while the
-       // coordinates we provide must always be signless.
-       Value trans = rewriter.create<affine::AffineApplyOp>(
-           op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
-           op.getInCrds());
-       outCrds.push_back(trans);
-     }
-     rewriter.replaceOp(op, outCrds);
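+ // Demaps a tensor.insert on a sparse tensor with a non-identity map: the
+ // dim-coordinates are translated to lvl-coordinates, the insertion is done
+ // on the demapped tensor, and the result is remapped to the original type.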
+ struct TensorInsertDemapper
+     : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
+   using DemapInsRewriter::DemapInsRewriter;
+   LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+                           PatternRewriter &rewriter) const {
+     if (!hasAnySparseResult(op))
+       return failure();
+
+     Location loc = op.getLoc();
+     auto stt = getSparseTensorType(op.getResult());
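+     // Translate the dim-coordinates of the insertion to lvl-coordinates
+     // following the tensor's dim2lvl map.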
+     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
+                                           CrdTransDirectionKind::dim2lvl);
+     auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
+         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
+
+     Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
+     rewriter.replaceOp(op, out);
    return success();
  }
};

- struct TensorInsertRewriter
-     : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
-   using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
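+ // Demaps a sparse_tensor.foreach in-place: the loop is rewritten to iterate
+ // over the demapped tensor's lvl-coordinates, which are translated back to
+ // dim-coordinates for the uses inside the body.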
+ struct ForeachOpDemapper
+     : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
+   using DemapInsRewriter::DemapInsRewriter;
+   LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
+                           PatternRewriter &rewriter) const {
+     // Only handle operations with sparse input/output with non-identity
+     // dim2lvl maps.
+     if (!hasNonIdentityOperandsOrResults(op))
+       return failure();

-   bool matchOp(tensor::InsertOp op) const {
-     return op.getResult().getType().getEncoding() != nullptr;
-   }
+     // TODO: demap constant as well.
+     if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
+       if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
+         return failure();

-   ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
-                        PatternRewriter &rewriter) const {
    Location loc = op.getLoc();
-     auto stt = getSparseTensorType(op.getResult());
-     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
-                                           CrdTransDirectionKind::dim2lvl);
-     Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
-         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
-     return insertOp->getResults();
+     // Cache the type information since we update the foreach op in-place.
+     auto srcStt = getSparseTensorType(op.getTensor());
+     SmallVector<Type> prevRetTps(op.getResultTypes());
+
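+     // All mutations below update the foreach op in-place, bracketed by
+     // startRootUpdate/finalizeRootUpdate to notify the rewriter.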
+     rewriter.startRootUpdate(op);
+     op.getTensorMutable().assign(adaptor.getTensor());
+     op.getInitArgsMutable().assign(adaptor.getInitArgs());
+     // Update results' types.
+     for (auto r : op.getResults())
+       if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
+         r.setType(stt->getDemappedType());
+
+     Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
+     // Update the foreach body.
+     SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
+     blockArgTps.push_back(srcStt.getElementType());
+     blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
+                        adaptor.getInitArgs().getTypes().end());
+     Block *body = op.getBody();
+     // Block Args: [dimCrds, val, initArgs]
+     unsigned preArgNum = body->getNumArguments();
+     for (Type t : blockArgTps)
+       body->addArgument(t, loc);
+
+     // Block Args: [dimCrds, val, initArgs, lvlCrds, val, DemappedArgs]
+     rewriter.setInsertionPointToStart(body);
+     ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
+
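+     // Translate the new lvl-coordinates back to dim-coordinates for the
+     // original block arguments' uses.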
+     ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
+                                               CrdTransDirectionKind::lvl2dim);
+     rewriter.replaceAllUsesWith(
+         body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
+     body->eraseArguments(0, srcStt.getDimRank());
+     // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
+     unsigned numInitArgs = op.getInitArgs().size();
+     rewriter.replaceAllUsesWith(body->getArgument(0),
+                                 body->getArgument(lvlRank + numInitArgs + 1));
+     body->eraseArgument(0);
+     // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
+     ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
+     ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
+     // Remap back before replacement.
+     SmallVector<Value> reMappedArgs =
+         remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
+     rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
+     body->eraseArguments(0, numInitArgs);
+     // Block Args: [lvlCrds, val, DemappedArgs] and we are done.
+
+     // Update yield operations.
+     if (numInitArgs != 0) {
+       rewriter.setInsertionPointToEnd(body);
+       auto yield = llvm::cast<YieldOp>(body->getTerminator());
+       if (auto stt = tryGetSparseTensorType(yield.getResult());
+           stt && !stt->isIdentity()) {
+         Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
+         rewriter.create<YieldOp>(loc, y);
+         rewriter.eraseOp(yield);
+       }
+     }
+     rewriter.finalizeRootUpdate(op);
+
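+     // Reinterpret the demapped results back to the original (mapped) types
+     // cached in prevRetTps.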
+     rewriter.setInsertionPointAfter(op);
+     SmallVector<Value> outs =
+         remapValueRange(rewriter, prevRetTps, op.getResults());
+
+     // Replace all the uses of the foreach results, except the uses in the
+     // reinterpret_map ops used to remap the outputs.
+     for (auto [from, to] : llvm::zip(op.getResults(), outs))
+       rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());
+
+     return success();
  }
};

@@ -234,7 +309,7 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
-     patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(
+     patterns.add<TensorInsertDemapper, ForeachOpDemapper>(
        patterns.getContext());
  }
}