Skip to content

Reapply "[mlir][sparse] implement lowering rules for IterateOp." #95836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 120 additions & 1 deletion mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
return success();
}

static std::optional<LogicalResult>
convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
  // An iterator is materialized as plain `index` cursor values that are
  // updated on every iteration.
  Type idxTp = IndexType::get(itTp.getContext());
  // TODO: handle batch dimension.
  assert(itTp.getEncoding().getBatchLvlRank() == 0);
  // A non-unique iterator additionally carries a segment-high position so
  // that duplicated coordinates can be skipped as one segment.
  if (!itTp.isUnique())
    fields.push_back(idxTp);
  // The current position.
  fields.push_back(idxTp);
  return success();
}

namespace {

/// Sparse codegen rule for number of entries operator.
Expand All @@ -57,10 +71,114 @@ class ExtractIterSpaceConverter
}
};

/// Lowering rule for `sparse_tensor.iterate`: rewrites the op into either an
/// `scf.for` (when the iterator reduces to a plain [lo, hi) position range)
/// or an `scf.while` (when explicit cursor values must be carried, e.g. for
/// deduplication of non-unique levels).
class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
public:
  using OneToNOpConversionPattern::OneToNOpConversionPattern;
  LogicalResult
  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    // Uses of coordinates inside the loop body are not supported yet.
    if (!op.getCrdUsedLvls().empty())
      return rewriter.notifyMatchFailure(
          op, "non-empty coordinates list not implemented.");

    Location loc = op.getLoc();

    // Reconstruct the iteration space from the 1:N-converted operand values.
    auto iterSpace = SparseIterationSpace::fromValues(
        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);

    std::unique_ptr<SparseIterator> it =
        iterSpace.extractIterator(rewriter, loc);

    if (it->iteratableByFor()) {
      // Simple case: a plain scf.for over the position range suffices.
      auto [lo, hi] = it->genForCond(rewriter, loc);
      Value step = constantIndex(rewriter, loc, 1);
      // Flatten the (possibly 1:N-expanded) loop-carried init values.
      SmallVector<Value> ivs;
      for (ValueRange inits : adaptor.getInitArgs())
        llvm::append_range(ivs, inits);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);

      // Convert the body block argument types to their 1:N-converted forms
      // before moving the body into the new loop.
      Block *loopBody = op.getBody();
      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
      if (failed(typeConverter->convertSignatureArgs(
              loopBody->getArgumentTypes(), bodyTypeMapping)))
        return failure();
      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);

      // Replace the placeholder body of the scf.for with the iterate body.
      rewriter.eraseBlock(forOp.getBody());
      Region &dstRegion = forOp.getRegion();
      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());

      auto yieldOp =
          llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());

      rewriter.setInsertionPointToEnd(forOp.getBody());
      // replace sparse_tensor.yield with scf.yield.
      rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
      rewriter.eraseOp(yieldOp);

      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
      rewriter.replaceOp(op, forOp.getResults(), resultMapping);
    } else {
      // General case: drive the iteration with an scf.while whose carried
      // values are the iterator cursor followed by the user init values.
      SmallVector<Value> ivs;
      llvm::append_range(ivs, it->getCursor());
      for (ValueRange inits : adaptor.getInitArgs())
        llvm::append_range(ivs, inits);

      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));

      TypeRange types = ValueRange(ivs).getTypes();
      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
      SmallVector<Location> l(types.size(), op.getIterator().getLoc());

      // Generates loop conditions.
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
      rewriter.setInsertionPointToStart(before);
      ValueRange bArgs = before->getArguments();
      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
      assert(remArgs.size() == adaptor.getInitArgs().size());
      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

      // Generates loop body.
      Block *loopBody = op.getBody();
      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
      if (failed(typeConverter->convertSignatureArgs(
              loopBody->getArgumentTypes(), bodyTypeMapping)))
        return failure();
      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);

      Region &dstRegion = whileOp.getAfter();
      // TODO: handle uses of coordinate!
      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
      ValueRange aArgs = whileOp.getAfterArguments();
      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
          whileOp.getAfterBody()->getTerminator());

      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());

      // Re-bind the iterator to the after-region arguments and advance it;
      // the forwarded cursor plus the user yields become the scf.yield values.
      aArgs = it->linkNewScope(aArgs);
      ValueRange nx = it->forward(rewriter, loc);
      SmallVector<Value> yields;
      llvm::append_range(yields, nx);
      llvm::append_range(yields, yieldOp.getResults());

      // replace sparse_tensor.yield with scf.yield.
      rewriter.eraseOp(yieldOp);
      rewriter.create<scf::YieldOp>(loc, yields);

      // Drop the cursor values from the scf.while results: only the
      // user-visible iter_args are results of the original op.
      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
      rewriter.replaceOp(
          op, whileOp.getResults().drop_front(it->getCursor().size()),
          resultMapping);
    }
    return success();
  }
};

} // namespace

mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
addConversion([](Type type) { return type; });
addConversion(convertIteratorType);
addConversion(convertIterSpaceType);

addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
Expand All @@ -74,5 +192,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {

/// Populates the patterns that lower the sparse iteration ops
/// (extract_iteration_space / iterate) to scf-dialect loops.
void mlir::populateLowerSparseIterationToSCFPatterns(
    TypeConverter &converter, RewritePatternSet &patterns) {
  // NOTE: the flattened diff left both the old single-pattern registration
  // and the new one in place, registering ExtractIterSpaceConverter twice;
  // keep only the merged registration.
  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
      converter, patterns.getContext());
}
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,13 @@ class TrivialIterator : public ConcreteIterator {
TrivialIterator(const SparseTensorLevel &stl)
: ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}

  // Constructs an iterator already positioned at `posLo` and bounded by
  // `posHi`. The builder/location parameters are unused in this body;
  // presumably they keep the signature uniform with DedupIterator's
  // equivalent constructor (which needs them) — TODO confirm.
  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
                  Value posLo, Value posHi)
      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
        posHi(posHi) {
    seek(posLo);
  }

std::string getDebugInterfacePrefix() const override {
return std::string("trivial<") + stl.toString() + ">";
}
Expand Down Expand Up @@ -420,6 +427,14 @@ class DedupIterator : public ConcreteIterator {
: ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
assert(!stl.isUnique());
}

  // Constructs a deduplicating iterator positioned at `posLo` and bounded by
  // `posHi`; the second cursor value (the segment-high position) is computed
  // eagerly via genSegmentHigh. Only meaningful for non-unique levels.
  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
                Value posLo, Value posHi)
      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
    assert(!stl.isUnique());
    seek({posLo, genSegmentHigh(b, l, posLo)});
  }

// For LLVM-style RTTI.
static bool classof(const SparseIterator *from) {
return from->kind == IterKind::kDedup;
Expand Down Expand Up @@ -1532,6 +1547,11 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
return space;
}

// Extracts an iterator over this iteration space; currently simply delegates
// to the simple-iterator factory.
std::unique_ptr<SparseIterator>
SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
  return makeSimpleIterator(b, l, *this);
}

//===----------------------------------------------------------------------===//
// SparseIterator factory functions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1590,6 +1610,26 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
return std::make_pair(std::move(stl), std::move(it));
}

// Creates an iterator over the last level of `iterSpace`, choosing a
// deduplicating iterator for non-unique levels and a trivial one otherwise.
std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
                                  const SparseIterationSpace &iterSpace) {
  // NOTE(review): dimension check intentionally disabled — presumably
  // multi-dimensional spaces are permitted here; confirm before re-enabling.
  // assert(iterSpace.getSpaceDim() == 1);
  std::unique_ptr<SparseIterator> ret;
  if (!iterSpace.isUnique()) {
    // We always deduplicate the non-unique level, but we should optimize it
    // away if possible.
    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
                                          iterSpace.getBoundLo(),
                                          iterSpace.getBoundHi());
  } else {
    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
                                            iterSpace.getBoundLo(),
                                            iterSpace.getBoundHi());
  }
  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
  return ret;
}

std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
SparseEmitStrategy strategy) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ class SparseIterationSpace {
  // Lower/upper bound values of the iteration space's position range.
  Value getBoundLo() const { return bound.first; }
  Value getBoundHi() const { return bound.second; }

// Extract an iterator to iterate over the sparse iteration space.
std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
Location l) const;

private:
SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
std::pair<Value, Value> bound;
Expand Down Expand Up @@ -192,6 +196,13 @@ class SparseIterator {
crd = nullptr;
}

  // Reconstructs an iterator directly from the provided ValueRange.
  static std::unique_ptr<SparseIterator>
  fromValues(IteratorType dstTp, ValueRange values, unsigned tid);

  // The inverse operation of `fromValues`.
  SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }

//
// Iterator properties.
//
Expand Down Expand Up @@ -345,12 +356,21 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
unsigned tid,
Level lvl);

/// Helper function to create a TensorLevel object from given `tensor`.
/// Helper function to create a TensorLevel object from given ValueRange.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
ValueRange buffers,
unsigned tid, Level l);
/// Helper function to create a simple SparseIterator object that iterates
/// over the SparseTensorLevel.

/// Helper function to create a simple SparseIterator object that iterates
/// over the entire iteration space.
std::unique_ptr<SparseIterator>
makeSimpleIterator(OpBuilder &b, Location l,
const SparseIterationSpace &iterSpace);

/// Helper function to create a simple SparseIterator object that iterates
/// over the sparse tensor level.
/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when
/// feature complete.
std::unique_ptr<SparseIterator> makeSimpleIterator(
const SparseTensorLevel &stl,
SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
Expand Down
54 changes: 41 additions & 13 deletions mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED

#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
Expand All @@ -7,17 +8,44 @@
)
}>

// CHECK-LABEL: func.func @sparse_1D_space(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
// CHECK: %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK: %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK: %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
// CHECK-LABEL: @sparse_iteration_to_scf
// // deduplication
// CHECK: scf.while {{.*}} {
// CHECK: } do {
// CHECK: }
// CHECK: scf.while {{.*}} {
// CHECK: } do {
// // actual computation
// CHECK: scf.for {{.*}} {
// CHECK: arith.addi
// CHECK: }
// // deduplication
// CHECK: scf.while {{.*}} {
// CHECK: } do {
// CHECK: }
// CHECK: scf.yield
// CHECK: }
// CHECK: return

// COLLAPSED-LABEL: @sparse_iteration_to_scf
// COLLAPSED: %[[RET:.*]] = scf.for {{.*}} {
// COLLAPSED: %[[VAL:.*]] = arith.addi
// COLLAPSED: scf.yield %[[VAL]] : index
// COLLAPSED: }
// COLLAPSED: return %[[RET]] : index
// Walks a 4x8 COO sparse tensor with nested sparse_tensor.iterate loops,
// incrementing the loop-carried index once per innermost iteration.
func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index {
  %i = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Outer space/loop over level 0.
  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
    // Inner space/loop over level 1, anchored at the outer iterator.
    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
      %k = arith.addi %inner, %c1 : index
      sparse_tensor.yield %k : index
    }
    sparse_tensor.yield %r2 : index
  }
  return %r1 : index
}
Loading