[mlir][sparse] Implement rewriters to reinterpret maps on foreach #70868

Merged · 7 commits · Nov 1, 2023
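Overview, based on the changes below: the reinterpret-map pass gains two patterns, TensorInsertDemapper and ForeachOpDemapper, that demap sparse operands whose encodings carry non-identity dim2lvl maps, rewrite the enclosed computation in terms of level coordinates, and reinterpret the results back to the original encoding. CrdTranslateRewriter moves from SparseReinterpretMap.cpp into the foreach-lowering patterns so that crd_translate ops can fold before being expanded to affine.apply.

A minimal sketch of the demapping step on the BSR encoding used in the new test; the demapped 1x2x2x2 type matches the CHECK lines, while the #DemappedBSR name and the exact printed form of reinterpret_map are illustrative assumptions, not text from this PR:

    #BSR = #sparse_tensor.encoding<{
      map = (i, j) -> (i floordiv 2 : dense, j floordiv 2 : compressed,
                       i mod 2 : dense, j mod 2 : dense)
    }>
    // A 2x4 BSR tensor, viewed through an identity level map, becomes a
    // 1x2x2x2 tensor (#DemappedBSR is a hypothetical name for that encoding).
    %demapped = sparse_tensor.reinterpret_map %bsr
        : tensor<2x4xf64, #BSR> to tensor<1x2x2x2xf64, #DemappedBSR>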
187 changes: 131 additions & 56 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -18,10 +18,8 @@
using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Translates a "simple" map according to an identity lvl-map.
@@ -51,6 +49,27 @@ static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
ValueRange outs) {
SmallVector<Value> ret(outs);
assert(outs.size() == types.size());
for (auto [r, t] : llvm::zip(ret, types))
if (r.getType() != t)
r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
return ret;
}

/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
static bool hasNonIdentityOperandsOrResults(Operation *op) {
auto hasNonIdentityMap = [](Value v) {
auto stt = tryGetSparseTensorType(v);
return stt && !stt->isIdentity();
};

return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
llvm::any_of(op->getResults(), hasNonIdentityMap);
}

// Generates a clone of the given linalg generic operation, but with
// remapped arguments, index maps, and iteration types.
//
Expand Down Expand Up @@ -86,6 +105,8 @@ static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
return newOp;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -142,21 +163,17 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
};

//===----------------------------------------------------------------------===//
// Rewriting rules for operations other than linalg generic ops.
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

// CRTP to help implement a rewriter that demaps all its inputs and remaps
// all its outputs.
// CRTP to help implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
using OpRewritePattern<SourceOp>::OpRewritePattern;
using OpAdaptor = typename SourceOp::Adaptor;

LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const override {
if (!static_cast<const SubClass *>(this)->matchOp(op))
return failure();

Location loc = op.getLoc();
// Demaps non-trivial inputs.
SmallVector<Value> deMappedIns(op->getOperands());
@@ -166,61 +183,119 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {

// CRTP call.
OpAdaptor adaptor(deMappedIns);
ValueRange outs =
static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
assert(outs.size() == op->getResults().size());

// Remap outputs.
SmallVector<Value> reMappedOuts(outs);
for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
if (r.getType() != a.getType())
r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);

rewriter.replaceOp(op, reMappedOuts);
return success();
return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
rewriter);
}
};

struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CrdTranslateOp op,
PatternRewriter &rewriter) const override {
AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
? op.getEncoder().getDimToLvl()
: op.getEncoder().getLvlToDim();

SmallVector<Value> outCrds;
for (AffineExpr result : map.getResults()) {
// TODO: we should probably expand the affine map to IR using our own
// rules, since affine.apply assumes signed values, while the coordinates
// we provide must always be signless.
Value trans = rewriter.create<affine::AffineApplyOp>(
op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
op.getInCrds());
outCrds.push_back(trans);
}
rewriter.replaceOp(op, outCrds);
struct TensorInsertDemapper
: public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
using DemapInsRewriter::DemapInsRewriter;
LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
PatternRewriter &rewriter) const {
if (!hasAnySparseResult(op))
return failure();

Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
CrdTransDirectionKind::dim2lvl);
auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
loc, op.getScalar(), adaptor.getDest(), lvlCrd);

Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
rewriter.replaceOp(op, out);
return success();
}
};

struct TensorInsertRewriter
: public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
struct ForeachOpDemapper
: public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
using DemapInsRewriter::DemapInsRewriter;
LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
PatternRewriter &rewriter) const {
// Only handle operations with sparse inputs/outputs that have non-identity
// dim2lvl maps.
if (!hasNonIdentityOperandsOrResults(op))
return failure();

bool matchOp(tensor::InsertOp op) const {
return op.getResult().getType().getEncoding() != nullptr;
}
// TODO: demap constants as well.
if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
return failure();

ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
CrdTransDirectionKind::dim2lvl);
Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
loc, op.getScalar(), adaptor.getDest(), lvlCrd);
return insertOp->getResults();
// Cache the type information since we update the foreach op in-place.
auto srcStt = getSparseTensorType(op.getTensor());
SmallVector<Type> prevRetTps(op.getResultTypes());

rewriter.startRootUpdate(op);
op.getTensorMutable().assign(adaptor.getTensor());
op.getInitArgsMutable().assign(adaptor.getInitArgs());
// Update results' types.
for (auto r : op.getResults())
if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
r.setType(stt->getDemappedType());

Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
// Update the foreach body.
SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
blockArgTps.push_back(srcStt.getElementType());
blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
adaptor.getInitArgs().getTypes().end());
Block *body = op.getBody();
// Block Args: [dimCrd, val, initArgs]
unsigned preArgNum = body->getNumArguments();
for (Type t : blockArgTps)
body->addArgument(t, loc);

// Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
rewriter.setInsertionPointToStart(body);
ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
CrdTransDirectionKind::lvl2dim);
rewriter.replaceAllUsesWith(
body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
body->eraseArguments(0, srcStt.getDimRank());
// Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
unsigned numInitArgs = op.getInitArgs().size();
rewriter.replaceAllUsesWith(body->getArgument(0),
body->getArgument(lvlRank + numInitArgs + 1));
body->eraseArgument(0);
// Block Args: [initArgs, lvlCrds, val, DemappedArgs]
ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
// Remap back before replacement.
SmallVector<Value> reMappedArgs =
remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
body->eraseArguments(0, numInitArgs);
// Block Args: [lvlCrds, val, DemappedArgs] and we are done.

// Update yield operations.
if (numInitArgs != 0) {
rewriter.setInsertionPointToEnd(body);
auto yield = llvm::cast<YieldOp>(body->getTerminator());
if (auto stt = tryGetSparseTensorType(yield.getResult());
stt && !stt->isIdentity()) {
Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
rewriter.create<YieldOp>(loc, y);
rewriter.eraseOp(yield);
}
}
rewriter.finalizeRootUpdate(op);

rewriter.setInsertionPointAfter(op);
SmallVector<Value> outs =
remapValueRange(rewriter, prevRetTps, op.getResults());

// Replace all the uses of the foreach results, except for the use in the
// reinterpret_map that remaps the output.
for (auto [from, to] : llvm::zip(op.getResults(), outs))
rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

return success();
}
};
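
// For orientation, a sketch of how the rewriter above evolves the foreach body
// signature for the 2x4 BSR tensor from the new test; the argument names and the
// #DemappedBSR encoding name are illustrative assumptions.
//
//   // Before: dimension coordinates, value, init argument.
//   ^bb0(%i : index, %j : index, %v : f64, %out : tensor<2x4xf64, #BSR>):
//
//   // After: level coordinates, value, demapped init argument; the original
//   // dimension coordinates are recovered through lvl2dim translation,
//   // i.e. %i = 2 * %i0 + %i1 and %j = 2 * %j0 + %j1 for this encoding.
//   ^bb0(%i0 : index, %j0 : index, %i1 : index, %j1 : index,
//        %v : f64, %out : tensor<1x2x2x2xf64, #DemappedBSR>):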

@@ -234,7 +309,7 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
}
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kExceptGeneric) {
patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(
patterns.add<TensorInsertDemapper, ForeachOpDemapper>(
patterns.getContext());
}
}
@@ -1063,6 +1063,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
}
};

struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CrdTranslateOp op,
PatternRewriter &rewriter) const override {
AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
? op.getEncoder().getDimToLvl()
: op.getEncoder().getLvlToDim();

SmallVector<Value> outCrds;
for (AffineExpr result : map.getResults()) {
// TODO: we should probably expand the affine map to IR using our own
// rules, since affine.apply assumes signed values, while the coordinates
// we provide must always be signless.
Value trans = rewriter.create<affine::AffineApplyOp>(
op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
op.getInCrds());
outCrds.push_back(trans);
}
rewriter.replaceOp(op, outCrds);
return success();
}
};

/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
@@ -1284,5 +1307,7 @@ void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
}

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
patterns.add<ForeachRewriter>(patterns.getContext());
// Run CrdTranslateRewriter later in the pipeline so that the operation can
// be folded before lowering to affine.apply.
patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
}
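
As a rough illustration, CrdTranslateRewriter expands a crd_translate into one affine.apply per result expression of the map; a hedged sketch for the dim2lvl direction of the BSR map used in the test below (SSA names are illustrative):

    // Translating dimension coordinates (%i, %j) through
    // (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2):
    %l0 = affine.apply affine_map<(d0, d1) -> (d0 floordiv 2)>(%i, %j)
    %l1 = affine.apply affine_map<(d0, d1) -> (d1 floordiv 2)>(%i, %j)
    %l2 = affine.apply affine_map<(d0, d1) -> (d0 mod 2)>(%i, %j)
    %l3 = affine.apply affine_map<(d0, d1) -> (d1 mod 2)>(%i, %j)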
48 changes: 36 additions & 12 deletions mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -1,15 +1,4 @@
// RUN: mlir-opt %s -split-input-file --sparse-reinterpret-map | FileCheck %s

#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

// CHECK-LABEL: func @sparse_nop(
// CHECK-SAME: %[[A0:.*]]: tensor<?xf64, #sparse_tensor.encoding<{{{.*}}}>>)
// CHECK: return %[[A0]]
func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
return %arg0 : tensor<?xf64, #SparseVector>
}

// -----
// RUN: mlir-opt %s -split-input-file --sparse-reinterpret-map | FileCheck %s

#trait_mul = {
indexing_maps = [
@@ -55,3 +44,38 @@ func.func @mul(%arg0: tensor<32x32xf32>,
return %0 : tensor<32x32xf32, #BSR>
}

// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
j floordiv 2 : compressed,
i mod 2 : dense,
j mod 2 : dense
)
}>

// CHECK-LABEL: func.func @sparse_foreach_reinterpret_map(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64
// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<2x4xf64
// CHECK: %[[VAL_2:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64
// CHECK: %[[VAL_3:.*]] = sparse_tensor.reinterpret_map %[[VAL_1]] : tensor<2x4xf64
// CHECK: %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_3]])
// CHECK: ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: f64, %[[VAL_10:.*]]: tensor<1x2x2x2xf64
// CHECK: %[[VAL_11:.*]] = sparse_tensor.insert %[[VAL_9]] into %[[VAL_10]]{{\[}}%[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]]] : tensor<1x2x2x2xf64
// CHECK: sparse_tensor.yield %[[VAL_11]] : tensor<1x2x2x2xf64
// CHECK: }
// CHECK: %[[VAL_12:.*]] = sparse_tensor.reinterpret_map %[[VAL_4]] : tensor<1x2x2x2xf64
// CHECK: %[[VAL_13:.*]] = sparse_tensor.load %[[VAL_12]] hasInserts : tensor<2x4xf64
// CHECK: return %[[VAL_13]] : tensor<2x4xf64
// CHECK: }
func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<2x4xf64, #BSR> {
%7 = bufferization.alloc_tensor() : tensor<2x4xf64, #BSR>
%8 = sparse_tensor.foreach in %6 init(%7) : tensor<2x4xf64, #BSR>, tensor<2x4xf64, #BSR> -> tensor<2x4xf64, #BSR> do {
^bb0(%arg0: index, %arg1: index, %arg2: f64, %arg3: tensor<2x4xf64, #BSR>):
%inserted = tensor.insert %arg2 into %arg3[%arg0, %arg1] : tensor<2x4xf64, #BSR>
sparse_tensor.yield %inserted : tensor<2x4xf64, #BSR>
}
%9 = sparse_tensor.load %8 hasInserts : tensor<2x4xf64, #BSR>
return %9 : tensor<2x4xf64, #BSR>
}
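
As a usage note, the RUN line above, mlir-opt %s -split-input-file --sparse-reinterpret-map, is what exercises the new TensorInsertDemapper and ForeachOpDemapper patterns on this test.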