Reapply "[mlir][sparse] implement lowering rules for IterateOp." #95836
Merged
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/95836.diff

4 Files Affected:
- mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
- mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
- mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
- mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
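This reapplied change lowers sparse_tensor.iterate into scf control flow: when the extracted iterator can be driven by a plain position range (it->iteratableByFor()), the op is rewritten to an scf.for; otherwise it becomes an scf.while over the iterator's cursor values. A minimal sketch of the for-loop path, with hypothetical %space, %lo, %hi, %c1, %init values and a hypothetical unique #ENC encoding (not verbatim compiler output):

  // Input: iterate over a 1-D iteration space with one iter_arg.
  %r = sparse_tensor.iterate %it in %space iter_args(%acc = %init)
      : !sparse_tensor.iter_space<#ENC, lvls = 0> -> index {
    %n = arith.addi %acc, %c1 : index
    sparse_tensor.yield %n : index
  }

  // After lowering: a plain loop over the position range [%lo, %hi).
  %r = scf.for %pos = %lo to %hi step %c1 iter_args(%acc = %init) -> (index) {
    %n = arith.addi %acc, %c1 : index   // Inlined body; the yield op is swapped.
    scf.yield %n : index
  }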
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 62887c75c872b..4224925147c84 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -34,6 +34,20 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
return success();
}
+static std::optional<LogicalResult>
+convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
+ // The actual iterator values (that are updated every iteration).
+ auto idxTp = IndexType::get(itTp.getContext());
+ // TODO: handle batch dimension.
+ assert(itTp.getEncoding().getBatchLvlRank() == 0);
+ if (!itTp.isUnique()) {
+ // Segment high for non-unique iterator.
+ fields.push_back(idxTp);
+ }
+ fields.push_back(idxTp);
+ return success();
+}
+
namespace {
/// Sparse codegen rule for number of entries operator.
@@ -57,10 +71,114 @@ class ExtractIterSpaceConverter
}
};
+class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+public:
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ LogicalResult
+ matchAndRewrite(IterateOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ if (!op.getCrdUsedLvls().empty())
+ return rewriter.notifyMatchFailure(
+ op, "non-empty coordinates list not implemented.");
+
+ Location loc = op.getLoc();
+
+ auto iterSpace = SparseIterationSpace::fromValues(
+ op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
+
+ std::unique_ptr<SparseIterator> it =
+ iterSpace.extractIterator(rewriter, loc);
+
+ if (it->iteratableByFor()) {
+ auto [lo, hi] = it->genForCond(rewriter, loc);
+ Value step = constantIndex(rewriter, loc, 1);
+ SmallVector<Value> ivs;
+ for (ValueRange inits : adaptor.getInitArgs())
+ llvm::append_range(ivs, inits);
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
+
+ Block *loopBody = op.getBody();
+ OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+ if (failed(typeConverter->convertSignatureArgs(
+ loopBody->getArgumentTypes(), bodyTypeMapping)))
+ return failure();
+ rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+ rewriter.eraseBlock(forOp.getBody());
+ Region &dstRegion = forOp.getRegion();
+ rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+
+ auto yieldOp =
+ llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
+
+ rewriter.setInsertionPointToEnd(forOp.getBody());
+ // replace sparse_tensor.yield with scf.yield.
+ rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
+ rewriter.eraseOp(yieldOp);
+
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+ rewriter.replaceOp(op, forOp.getResults(), resultMapping);
+ } else {
+ SmallVector<Value> ivs;
+ llvm::append_range(ivs, it->getCursor());
+ for (ValueRange inits : adaptor.getInitArgs())
+ llvm::append_range(ivs, inits);
+
+ assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+
+ TypeRange types = ValueRange(ivs).getTypes();
+ auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+ SmallVector<Location> l(types.size(), op.getIterator().getLoc());
+
+ // Generates loop conditions.
+ Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+ rewriter.setInsertionPointToStart(before);
+ ValueRange bArgs = before->getArguments();
+ auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
+ assert(remArgs.size() == adaptor.getInitArgs().size());
+ rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+ // Generates loop body.
+ Block *loopBody = op.getBody();
+ OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+ if (failed(typeConverter->convertSignatureArgs(
+ loopBody->getArgumentTypes(), bodyTypeMapping)))
+ return failure();
+ rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+ Region &dstRegion = whileOp.getAfter();
+ // TODO: handle uses of coordinate!
+ rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+ ValueRange aArgs = whileOp.getAfterArguments();
+ auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
+ whileOp.getAfterBody()->getTerminator());
+
+ rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+
+ aArgs = it->linkNewScope(aArgs);
+ ValueRange nx = it->forward(rewriter, loc);
+ SmallVector<Value> yields;
+ llvm::append_range(yields, nx);
+ llvm::append_range(yields, yieldOp.getResults());
+
+ // replace sparse_tensor.yield with scf.yield.
+ rewriter.eraseOp(yieldOp);
+ rewriter.create<scf::YieldOp>(loc, yields);
+
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+ rewriter.replaceOp(
+ op, whileOp.getResults().drop_front(it->getCursor().size()),
+ resultMapping);
+ }
+ return success();
+ }
+};
+
} // namespace
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
addConversion([](Type type) { return type; });
+ addConversion(convertIteratorType);
addConversion(convertIterSpaceType);
addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
@@ -74,5 +192,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
void mlir::populateLowerSparseIterationToSCFPatterns(
TypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+ patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
+ converter, patterns.getContext());
}
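For the while-loop path, the carried values are the iterator cursor followed by the user's iter_args: the before region holds the continuation test built by genWhileCond, and the after region inlines the original body and then advances the iterator via it->forward(). A sketch with a single-position cursor and hypothetical %lo, %hi, %init values (a non-unique level would carry an extra segment-high cursor value):

  %c1 = arith.constant 1 : index
  %res:2 = scf.while (%pos = %lo, %acc = %init)
      : (index, index) -> (index, index) {
    %cond = arith.cmpi ult, %pos, %hi : index
    scf.condition(%cond) %pos, %acc : index, index
  } do {
  ^bb0(%p: index, %a: index):
    %na = arith.addi %a, %c1 : index   // Inlined loop body.
    %np = arith.addi %p, %c1 : index   // it->forward(): advance the cursor.
    scf.yield %np, %na : index, index
  }
  // Only the trailing user values replace the op's results:
  // whileOp.getResults().drop_front(it->getCursor().size()), i.e. %res#1.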
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index be8e15d6ae6f4..ef95fcc84bd90 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -331,6 +331,13 @@ class TrivialIterator : public ConcreteIterator {
TrivialIterator(const SparseTensorLevel &stl)
: ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
+ TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+ Value posLo, Value posHi)
+ : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
+ posHi(posHi) {
+ seek(posLo);
+ }
+
std::string getDebugInterfacePrefix() const override {
return std::string("trivial<") + stl.toString() + ">";
}
@@ -420,6 +427,14 @@ class DedupIterator : public ConcreteIterator {
: ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
assert(!stl.isUnique());
}
+
+ DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+ Value posLo, Value posHi)
+ : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
+ assert(!stl.isUnique());
+ seek({posLo, genSegmentHigh(b, l, posLo)});
+ }
+
// For LLVM-style RTTI.
static bool classof(const SparseIterator *from) {
return from->kind == IterKind::kDedup;
@@ -1532,6 +1547,11 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
return space;
}
+std::unique_ptr<SparseIterator>
+SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
+ return makeSimpleIterator(b, l, *this);
+}
+
//===----------------------------------------------------------------------===//
// SparseIterator factory functions.
//===----------------------------------------------------------------------===//
@@ -1590,6 +1610,26 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
return std::make_pair(std::move(stl), std::move(it));
}
+std::unique_ptr<SparseIterator>
+sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
+ const SparseIterationSpace &iterSpace) {
+ // assert(iterSpace.getSpaceDim() == 1);
+ std::unique_ptr<SparseIterator> ret;
+ if (!iterSpace.isUnique()) {
+ // We always deduplicate the non-unique level, but we should optimize it away
+ // if possible.
+ ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
+ iterSpace.getBoundLo(),
+ iterSpace.getBoundHi());
+ } else {
+ ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
+ iterSpace.getBoundLo(),
+ iterSpace.getBoundHi());
+ }
+ ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
+ return ret;
+}
+
std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
SparseEmitStrategy strategy) {
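Note how this lines up with convertIteratorType in SparseIterationToScf.cpp above: a unique level's iterator lowers to a single index (the current position), while a non-unique level, which makeSimpleIterator always wraps in a DedupIterator, carries an extra index for the segment high, i.e. the exclusive end of the current run of duplicate coordinates. Sketched as 1:N type conversions, with hypothetical #UNIQUE and #NON_UNIQUE encodings:

  !sparse_tensor.iterator<#UNIQUE, lvls = 0>      // -> (index)
  !sparse_tensor.iterator<#NON_UNIQUE, lvls = 0>  // -> (index, index)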
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 17636af2b2f9d..91f363db93f1d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -132,6 +132,10 @@ class SparseIterationSpace {
Value getBoundLo() const { return bound.first; }
Value getBoundHi() const { return bound.second; }
+ // Extract an iterator to iterate over the sparse iteration space.
+ std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
+ Location l) const;
+
private:
SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
std::pair<Value, Value> bound;
@@ -192,6 +196,13 @@ class SparseIterator {
crd = nullptr;
}
+ // Reconstructs an iterator directly from the provided ValueRange.
+ static std::unique_ptr<SparseIterator>
+ fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
+
+ // The inverse operation of `fromValues`.
+ SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
+
//
// Iterator properties.
//
@@ -345,12 +356,21 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
unsigned tid,
Level lvl);
-/// Helper function to create a TensorLevel object from given `tensor`.
+/// Helper function to create a TensorLevel object from the given ValueRange.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
ValueRange buffers,
unsigned tid, Level l);
-/// Helper function to create a simple SparseIterator object that iterates
-/// over the SparseTensorLevel.
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the entire iteration space.
+std::unique_ptr<SparseIterator>
+makeSimpleIterator(OpBuilder &b, Location l,
+ const SparseIterationSpace &iterSpace);
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the sparse tensor level.
+/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when
+/// feature complete.
std::unique_ptr<SparseIterator> makeSimpleIterator(
const SparseTensorLevel &stl,
SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
index 5fcd661bb69b2..77a0e89dc7c81 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED
#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
@@ -7,17 +8,44 @@
)
}>
-// CHECK-LABEL: func.func @sparse_1D_space(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK: %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK: %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
-// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK: %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
-func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
- return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+// CHECK-LABEL: @sparse_iteration_to_scf
+// // deduplication
+// CHECK: scf.while {{.*}} {
+// CHECK: } do {
+// CHECK: }
+// CHECK: scf.while {{.*}} {
+// CHECK: } do {
+// // actual computation
+// CHECK: scf.for {{.*}} {
+// CHECK: arith.addi
+// CHECK: }
+// // deduplication
+// CHECK: scf.while {{.*}} {
+// CHECK: } do {
+// CHECK: }
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: return
+
+// COLLAPSED-LABEL: @sparse_iteration_to_scf
+// COLLAPSED: %[[RET:.*]] = scf.for {{.*}} {
+// COLLAPSED: %[[VAL:.*]] = arith.addi
+// COLLAPSED: scf.yield %[[VAL]] : index
+// COLLAPSED: }
+// COLLAPSED: return %[[RET]] : index
+func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index {
+ %i = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+ : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+ %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
+ %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+ %k = arith.addi %inner, %c1 : index
+ sparse_tensor.yield %k : index
+ }
+ sparse_tensor.yield %r2 : index
+ }
+ return %r1 : index
}
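The two RUN lines exercise both pipelines. Direct lowering keeps the nested structure, including the extra deduplication scf.while loops for the non-unique COO level; running --sparse-space-collapse first merges the two 1-D spaces, so the nested iterate ops fold into the single scf.for checked under COLLAPSED. A sketch of what the collapsed input is assumed to look like before lowering (not verbatim pass output):

  %l = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 2>
  %r = sparse_tensor.iterate %it in %l iter_args(%acc = %i)
      : !sparse_tensor.iter_space<#COO, lvls = 0 to 2> -> index {
    %k = arith.addi %acc, %c1 : index
    sparse_tensor.yield %k : index
  }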
aartbik approved these changes on Jun 17, 2024.