Commit 3426d33
[mlir][sparse] Implement rewriters to reinterpret maps on foreach (#70868)
1 parent a62b86a commit 3426d33
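In effect, the new ForeachOpDemapper rewrites a sparse_tensor.foreach over a tensor with a non-identity dim2lvl map so that the loop body iterates the demapped (level-coordinate) view, with sparse_tensor.reinterpret_map ops bridging the two views at the boundaries. A condensed before/after sketch in MLIR, derived from the new test case at the end of this commit (#BSR is defined there; the #Demapped alias is ours, standing in for the demapped encoding that is printed inline in the real IR):

// Before: the foreach iterates dimension coordinates of the 2x4 BSR tensor.
%out = sparse_tensor.foreach in %in init(%init)
    : tensor<2x4xf64, #BSR>, tensor<2x4xf64, #BSR> -> tensor<2x4xf64, #BSR> do {
^bb0(%i: index, %j: index, %v: f64, %acc: tensor<2x4xf64, #BSR>):
  ...
}

// After: inputs are demapped, the body receives level coordinates, and the
// result is remapped back to the original type.
%din   = sparse_tensor.reinterpret_map %in : tensor<2x4xf64, #BSR> to tensor<1x2x2x2xf64, #Demapped>
%dinit = sparse_tensor.reinterpret_map %init : tensor<2x4xf64, #BSR> to tensor<1x2x2x2xf64, #Demapped>
%dout  = sparse_tensor.foreach in %din init(%dinit) : ... do {
^bb0(%l0: index, %l1: index, %l2: index, %l3: index, %v: f64, %acc: tensor<1x2x2x2xf64, #Demapped>):
  ...
}
%out = sparse_tensor.reinterpret_map %dout : tensor<1x2x2x2xf64, #Demapped> to tensor<2x4xf64, #BSR>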

File tree

3 files changed, +193 -69 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 131 additions & 56 deletions
@@ -18,10 +18,8 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-namespace {
-
 //===----------------------------------------------------------------------===//
-// Helper methods.
+// File-local helper methods.
 //===----------------------------------------------------------------------===//
 
 // Translates a "simple" map according to an identity lvl-map.
@@ -51,6 +49,27 @@ static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
   return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
 }
 
+static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
+                                          ValueRange outs) {
+  SmallVector<Value> ret(outs);
+  assert(outs.size() == types.size());
+  for (auto [r, t] : llvm::zip(ret, types))
+    if (r.getType() != t)
+      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
+  return ret;
+}
+
+/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
+static bool hasNonIdentityOperandsOrResults(Operation *op) {
+  auto hasNonIdentityMap = [](Value v) {
+    auto stt = tryGetSparseTensorType(v);
+    return stt && !stt->isIdentity();
+  };
+
+  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
+         llvm::any_of(op->getResults(), hasNonIdentityMap);
+}
+
 // Generates a clone of the given linalg generic operation, but with
 // remapped arguments, index maps, and iteration types.
 //
@@ -86,6 +105,8 @@ static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
   return newOp;
 }
 
+namespace {
+
 //===----------------------------------------------------------------------===//
 // Rewriting rules for linalg generic ops.
 //===----------------------------------------------------------------------===//
@@ -142,21 +163,17 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
 };
 
 //===----------------------------------------------------------------------===//
-// Rewriting rules for operations other than linalg generic ops.
+// Reinterpret Map Rewriters for operations other than linalg.generics.
 //===----------------------------------------------------------------------===//
 
-// CRTP to help implementing a rewriter that demaps all its inputs and remaps
-// all its outputs.
+// CRTP to help implement a rewriter that demaps all its inputs.
 template <typename SubClass, typename SourceOp>
-struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
   using OpRewritePattern<SourceOp>::OpRewritePattern;
   using OpAdaptor = typename SourceOp::Adaptor;
 
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!static_cast<const SubClass *>(this)->matchOp(op))
-      return failure();
-
     Location loc = op.getLoc();
     // Demaps non-trivial inputs.
     SmallVector<Value> deMappedIns(op->getOperands());
@@ -166,61 +183,119 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
 
     // CRTP call.
     OpAdaptor adaptor(deMappedIns);
-    ValueRange outs =
-        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
-    assert(outs.size() == op->getResults().size());
-
-    // Remap outputs.
-    SmallVector<Value> reMappedOuts(outs);
-    for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
-      if (r.getType() != a.getType())
-        r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
-
-    rewriter.replaceOp(op, reMappedOuts);
-    return success();
+    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+                                                          rewriter);
   }
 };
 
-struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(CrdTranslateOp op,
-                                PatternRewriter &rewriter) const override {
-    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
-                        ? op.getEncoder().getDimToLvl()
-                        : op.getEncoder().getLvlToDim();
-
-    SmallVector<Value> outCrds;
-    for (AffineExpr result : map.getResults()) {
-      // TODO: we should probably expand the affine map to IR using our own
-      // rules, since affine.apply assume signed value, while the cooridinates
-      // we provided must always be signless.
-      Value trans = rewriter.create<affine::AffineApplyOp>(
-          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
-          op.getInCrds());
-      outCrds.push_back(trans);
-    }
-    rewriter.replaceOp(op, outCrds);
+struct TensorInsertDemapper
+    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
+  using DemapInsRewriter::DemapInsRewriter;
+  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    if (!hasAnySparseResult(op))
+      return failure();
+
+    Location loc = op.getLoc();
+    auto stt = getSparseTensorType(op.getResult());
+    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
+                                          CrdTransDirectionKind::dim2lvl);
+    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
+        loc, op.getScalar(), adaptor.getDest(), lvlCrd);
+
+    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
+    rewriter.replaceOp(op, out);
     return success();
   }
 };
 
-struct TensorInsertRewriter
-    : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
-  using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+struct ForeachOpDemapper
+    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
+  using DemapInsRewriter::DemapInsRewriter;
+  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    // Only handle operations with sparse input/output with non-identity
+    // dim2lvl maps.
+    if (!hasNonIdentityOperandsOrResults(op))
+      return failure();
 
-  bool matchOp(tensor::InsertOp op) const {
-    return op.getResult().getType().getEncoding() != nullptr;
-  }
+    // TODO: demap constant as well.
+    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
+      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
+        return failure();
 
-  ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
-                       PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
-    auto stt = getSparseTensorType(op.getResult());
-    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
-                                          CrdTransDirectionKind::dim2lvl);
-    Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
-        loc, op.getScalar(), adaptor.getDest(), lvlCrd);
-    return insertOp->getResults();
+    // Cache the type information since we update the foreach op in-place.
+    auto srcStt = getSparseTensorType(op.getTensor());
+    SmallVector<Type> prevRetTps(op.getResultTypes());
+
+    rewriter.startRootUpdate(op);
+    op.getTensorMutable().assign(adaptor.getTensor());
+    op.getInitArgsMutable().assign(adaptor.getInitArgs());
+    // Update results' types.
+    for (auto r : op.getResults())
+      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
+        r.setType(stt->getDemappedType());
+
+    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
+    // Update the foreach body.
+    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
+    blockArgTps.push_back(srcStt.getElementType());
+    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
+                       adaptor.getInitArgs().getTypes().end());
+    Block *body = op.getBody();
+    // Block Args: [dimCrd, val, initArgs]
+    unsigned preArgNum = body->getNumArguments();
+    for (Type t : blockArgTps)
+      body->addArgument(t, loc);
+
+    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
+    rewriter.setInsertionPointToStart(body);
+    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
+
+    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
+                                              CrdTransDirectionKind::lvl2dim);
+    rewriter.replaceAllUsesWith(
+        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
+    body->eraseArguments(0, srcStt.getDimRank());
+    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
+    unsigned numInitArgs = op.getInitArgs().size();
+    rewriter.replaceAllUsesWith(body->getArgument(0),
+                                body->getArgument(lvlRank + numInitArgs + 1));
+    body->eraseArgument(0);
+    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
+    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
+    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
+    // Remap back before replacement.
+    SmallVector<Value> reMappedArgs =
+        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
+    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
+    body->eraseArguments(0, numInitArgs);
+    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.
+
+    // Update yield operations.
+    if (numInitArgs != 0) {
+      rewriter.setInsertionPointToEnd(body);
+      auto yield = llvm::cast<YieldOp>(body->getTerminator());
+      if (auto stt = tryGetSparseTensorType(yield.getResult());
+          stt && !stt->isIdentity()) {
+        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
+        rewriter.create<YieldOp>(loc, y);
+        rewriter.eraseOp(yield);
+      }
+    }
+    rewriter.finalizeRootUpdate(op);
+
+    rewriter.setInsertionPointAfter(op);
+    SmallVector<Value> outs =
+        remapValueRange(rewriter, prevRetTps, op.getResults());
+
+    // Replace all the uses of the foreach results, except the use in
+    // reinterpret_map used to remap the output.
+    for (auto [from, to] : llvm::zip(op.getResults(), outs))
+      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());
+
+    return success();
   }
 };
 
@@ -234,7 +309,7 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
-    patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(
+    patterns.add<TensorInsertDemapper, ForeachOpDemapper>(
         patterns.getContext());
   }
 }
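Note that TensorInsertDemapper now remaps its own output (via genRemap) instead of relying on the removed DemapInsRemapOutsRewriter machinery. A minimal sketch of the rewrite it performs, assuming the #BSR encoding from the test file below and a hypothetical #Demapped alias for the demapped type; the dim-to-lvl translation produced by translateCrds is shown as a sparse_tensor.crd_translate op:

// Before (dimension coordinates):
%r = tensor.insert %v into %t[%i, %j] : tensor<2x4xf64, #BSR>

// After (level coordinates), sketched:
%d = sparse_tensor.reinterpret_map %t : tensor<2x4xf64, #BSR> to tensor<1x2x2x2xf64, #Demapped>
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%i, %j] as #BSR
    : index, index, index, index
%s = sparse_tensor.insert %v into %d[%l0, %l1, %l2, %l3] : tensor<1x2x2x2xf64, #Demapped>
%r = sparse_tensor.reinterpret_map %s : tensor<1x2x2x2xf64, #Demapped> to tensor<2x4xf64, #BSR>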

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 26 additions & 1 deletion
@@ -1063,6 +1063,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
   }
 };
 
+struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CrdTranslateOp op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
+                        ? op.getEncoder().getDimToLvl()
+                        : op.getEncoder().getLvlToDim();
+
+    SmallVector<Value> outCrds;
+    for (AffineExpr result : map.getResults()) {
+      // TODO: we should probably expand the affine map to IR using our own
+      // rules, since affine.apply assumes signed values, while the coordinates
+      // we provide must always be signless.
+      Value trans = rewriter.create<affine::AffineApplyOp>(
+          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
+          op.getInCrds());
+      outCrds.push_back(trans);
+    }
+    rewriter.replaceOp(op, outCrds);
+    return success();
+  }
+};
+
 /// Sparse rewriting rule for the foreach operator.
 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
 public:
@@ -1284,5 +1307,7 @@ void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
 }
 
 void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
-  patterns.add<ForeachRewriter>(patterns.getContext());
+  // Run CrdTranslateRewriter later in the pipeline so that the operation can
+  // be folded before lowering to affine.apply.
+  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
 }
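CrdTranslateRewriter expands each result expression of the dim2lvl (or lvl2dim) map into its own affine.apply over the full set of input coordinates (AffineMap::get keeps all dims for every single-result map). For the #BSR map (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2) used in the test file, the lowering would look roughly like this sketch:

// Before:
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%i, %j] as #BSR
    : index, index, index, index

// After, one affine.apply per level expression:
%l0 = affine.apply affine_map<(d0, d1) -> (d0 floordiv 2)>(%i, %j)
%l1 = affine.apply affine_map<(d0, d1) -> (d1 floordiv 2)>(%i, %j)
%l2 = affine.apply affine_map<(d0, d1) -> (d0 mod 2)>(%i, %j)
%l3 = affine.apply affine_map<(d0, d1) -> (d1 mod 2)>(%i, %j)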

mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir

Lines changed: 36 additions & 12 deletions
@@ -1,15 +1,4 @@
-// RUN: mlir-opt %s -split-input-file --sparse-reinterpret-map | FileCheck %s
-
-#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
-
-// CHECK-LABEL: func @sparse_nop(
-//  CHECK-SAME: %[[A0:.*]]: tensor<?xf64, #sparse_tensor.encoding<{{{.*}}}>>)
-//       CHECK: return %[[A0]]
-func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
-  return %arg0 : tensor<?xf64, #SparseVector>
-}
-
-// -----
+// RUN: mlir-opt %s -split-input-file --sparse-reinterpret-map | FileCheck %s
 
 #trait_mul = {
   indexing_maps = [
@@ -55,3 +44,38 @@ func.func @mul(%arg0: tensor<32x32xf32>,
   return %0 : tensor<32x32xf32, #BSR>
 }
 
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 2 : compressed,
+    i mod 2      : dense,
+    j mod 2      : dense
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_foreach_reinterpret_map(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x4xf64
+// CHECK:           %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<2x4xf64
+// CHECK:           %[[VAL_2:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.reinterpret_map %[[VAL_1]] : tensor<2x4xf64
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_3]])
+// CHECK:           ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: f64, %[[VAL_10:.*]]: tensor<1x2x2x2xf64
+// CHECK:             %[[VAL_11:.*]] = sparse_tensor.insert %[[VAL_9]] into %[[VAL_10]]{{\[}}%[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]]] : tensor<1x2x2x2xf64
+// CHECK:             sparse_tensor.yield %[[VAL_11]] : tensor<1x2x2x2xf64
+// CHECK:           }
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.reinterpret_map %[[VAL_4]] : tensor<1x2x2x2xf64
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.load %[[VAL_12]] hasInserts : tensor<2x4xf64
+// CHECK:           return %[[VAL_13]] : tensor<2x4xf64
+// CHECK:         }
+func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<2x4xf64, #BSR> {
+  %7 = bufferization.alloc_tensor() : tensor<2x4xf64, #BSR>
+  %8 = sparse_tensor.foreach in %6 init(%7) : tensor<2x4xf64, #BSR>, tensor<2x4xf64, #BSR> -> tensor<2x4xf64, #BSR> do {
+    ^bb0(%arg0: index, %arg1: index, %arg2: f64, %arg3: tensor<2x4xf64, #BSR>):
+      %inserted = tensor.insert %arg2 into %arg3[%arg0, %arg1] : tensor<2x4xf64, #BSR>
+      sparse_tensor.yield %inserted : tensor<2x4xf64, #BSR>
+  }
+  %9 = sparse_tensor.load %8 hasInserts : tensor<2x4xf64, #BSR>
+  return %9 : tensor<2x4xf64, #BSR>
+}
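For reference, the tensor<1x2x2x2xf64> appearing in the CHECK lines is the demapped form of the 2x4 #BSR tensor: the level sizes under (i floordiv 2, j floordiv 2, i mod 2, j mod 2) are 1, 2, 2, 2, and the level types carry over under an identity map. A sketch of the encoding behind that type (the #DemappedBSR name is ours; the printed attribute is elided by the {{.*}}-style CHECK patterns):

#DemappedBSR = #sparse_tensor.encoding<{
  map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : dense, d3 : dense)
}>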
