
Reapply "[mlir][sparse] implement lowering rules for IterateOp." #95836


Merged: 2 commits merged into llvm:main on Jun 17, 2024

Conversation

PeimingLiu (Member)

No description provided.

llvmbot (Member) commented Jun 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/95836.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+120-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (+40)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (+23-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir (+41-13)
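
For readers skimming the patch: the new SparseIterateOpConverter (first file below) lowers sparse_tensor.iterate into plain SCF control flow. When the underlying iterator can be driven by a simple position range (it->iteratableByFor()), it emits an scf.for; otherwise it emits an scf.while whose loop-carried values are the iterator's cursor values followed by the user's iter_args. A minimal before/after sketch of the scf.for case (all value and space names here are illustrative, not taken from the patch):

  // Before: iterate over a 1-D sparse iteration space, carrying %sum.
  %r = sparse_tensor.iterate %it in %space iter_args(%sum = %c0)
      : !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
    %n = arith.addi %sum, %c1 : index
    sparse_tensor.yield %n : index
  }

  // After: a plain loop over the [%posLo, %posHi) position range with
  // step 1; sparse_tensor.yield has been rewritten to scf.yield.
  %r = scf.for %pos = %posLo to %posHi step %c1 iter_args(%sum = %c0) -> (index) {
    %n = arith.addi %sum, %c1 : index
    scf.yield %n : index
  }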
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 62887c75c872b..4224925147c84 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -34,6 +34,20 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
   return success();
 }
 
+static std::optional<LogicalResult>
+convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
+  // The actual iterator values (that are updated every iteration).
+  auto idxTp = IndexType::get(itTp.getContext());
+  // TODO: handle batch dimension.
+  assert(itTp.getEncoding().getBatchLvlRank() == 0);
+  if (!itTp.isUnique()) {
+    // Segment high for non-unique iterator.
+    fields.push_back(idxTp);
+  }
+  fields.push_back(idxTp);
+  return success();
+}
+
 namespace {
 
 /// Sparse codegen rule for number of entries operator.
@@ -57,10 +71,114 @@ class ExtractIterSpaceConverter
   }
 };
 
+class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (!op.getCrdUsedLvls().empty())
+      return rewriter.notifyMatchFailure(
+          op, "non-empty coordinates list not implemented.");
+
+    Location loc = op.getLoc();
+
+    auto iterSpace = SparseIterationSpace::fromValues(
+        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
+
+    std::unique_ptr<SparseIterator> it =
+        iterSpace.extractIterator(rewriter, loc);
+
+    if (it->iteratableByFor()) {
+      auto [lo, hi] = it->genForCond(rewriter, loc);
+      Value step = constantIndex(rewriter, loc, 1);
+      SmallVector<Value> ivs;
+      for (ValueRange inits : adaptor.getInitArgs())
+        llvm::append_range(ivs, inits);
+      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
+
+      Block *loopBody = op.getBody();
+      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+      if (failed(typeConverter->convertSignatureArgs(
+              loopBody->getArgumentTypes(), bodyTypeMapping)))
+        return failure();
+      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+      rewriter.eraseBlock(forOp.getBody());
+      Region &dstRegion = forOp.getRegion();
+      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+
+      auto yieldOp =
+          llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
+
+      rewriter.setInsertionPointToEnd(forOp.getBody());
+      // Replace sparse_tensor.yield with scf.yield.
+      rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
+      rewriter.eraseOp(yieldOp);
+
+      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+      rewriter.replaceOp(op, forOp.getResults(), resultMapping);
+    } else {
+      SmallVector<Value> ivs;
+      llvm::append_range(ivs, it->getCursor());
+      for (ValueRange inits : adaptor.getInitArgs())
+        llvm::append_range(ivs, inits);
+
+      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+
+      TypeRange types = ValueRange(ivs).getTypes();
+      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+      SmallVector<Location> l(types.size(), op.getIterator().getLoc());
+
+      // Generates loop conditions.
+      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+      rewriter.setInsertionPointToStart(before);
+      ValueRange bArgs = before->getArguments();
+      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
+      assert(remArgs.size() == adaptor.getInitArgs().size());
+      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+      // Generates loop body.
+      Block *loopBody = op.getBody();
+      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+      if (failed(typeConverter->convertSignatureArgs(
+              loopBody->getArgumentTypes(), bodyTypeMapping)))
+        return failure();
+      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+      Region &dstRegion = whileOp.getAfter();
+      // TODO: handle uses of coordinate!
+      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+      ValueRange aArgs = whileOp.getAfterArguments();
+      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
+          whileOp.getAfterBody()->getTerminator());
+
+      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+
+      aArgs = it->linkNewScope(aArgs);
+      ValueRange nx = it->forward(rewriter, loc);
+      SmallVector<Value> yields;
+      llvm::append_range(yields, nx);
+      llvm::append_range(yields, yieldOp.getResults());
+
+      // Replace sparse_tensor.yield with scf.yield.
+      rewriter.eraseOp(yieldOp);
+      rewriter.create<scf::YieldOp>(loc, yields);
+
+      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+      rewriter.replaceOp(
+          op, whileOp.getResults().drop_front(it->getCursor().size()),
+          resultMapping);
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
   addConversion([](Type type) { return type; });
+  addConversion(convertIteratorType);
   addConversion(convertIterSpaceType);
 
   addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
@@ -74,5 +192,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
 
 void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
+      converter, patterns.getContext());
 }
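
A note on the 1:N type conversion registered above: convertIteratorType expands an iterator value into plain index values, one for a unique level and two for a non-unique one (the extra value is the segment high used to step over duplicate coordinates). Schematically (the #CSR encoding is assumed for illustration and does not appear in this patch):

  // Unique level: the iterator is just a position.
  !sparse_tensor.iterator<#CSR, lvls = 1>  ~>  (index)          // pos
  // Non-unique level, e.g. the compressed(nonunique) level 0 of #COO:
  // a segment-high value is carried in addition to the position.
  !sparse_tensor.iterator<#COO, lvls = 0>  ~>  (index, index)   // segHi, pos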
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index be8e15d6ae6f4..ef95fcc84bd90 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -331,6 +331,13 @@ class TrivialIterator : public ConcreteIterator {
   TrivialIterator(const SparseTensorLevel &stl)
       : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
 
+  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+                  Value posLo, Value posHi)
+      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
+        posHi(posHi) {
+    seek(posLo);
+  }
+
   std::string getDebugInterfacePrefix() const override {
     return std::string("trivial<") + stl.toString() + ">";
   }
@@ -420,6 +427,14 @@ class DedupIterator : public ConcreteIterator {
       : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
     assert(!stl.isUnique());
   }
+
+  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+                Value posLo, Value posHi)
+      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
+    assert(!stl.isUnique());
+    seek({posLo, genSegmentHigh(b, l, posLo)});
+  }
+
   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
     return from->kind == IterKind::kDedup;
@@ -1532,6 +1547,11 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
   return space;
 }
 
+std::unique_ptr<SparseIterator>
+SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
+  return makeSimpleIterator(b, l, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
@@ -1590,6 +1610,26 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
   return std::make_pair(std::move(stl), std::move(it));
 }
 
+std::unique_ptr<SparseIterator>
+sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
+                                  const SparseIterationSpace &iterSpace) {
+  // assert(iterSpace.getSpaceDim() == 1);
+  std::unique_ptr<SparseIterator> ret;
+  if (!iterSpace.isUnique()) {
+  // We always deduplicate the non-unique level, but we should optimize it away
+    // if possible.
+    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
+                                          iterSpace.getBoundLo(),
+                                          iterSpace.getBoundHi());
+  } else {
+    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
+                                            iterSpace.getBoundLo(),
+                                            iterSpace.getBoundHi());
+  }
+  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
+  return ret;
+}
+
 std::unique_ptr<SparseIterator>
 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
                                   SparseEmitStrategy strategy) {
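
The new DedupIterator constructor above calls genSegmentHigh to locate, at construction time, the end of the first run of duplicate coordinates. As a rough illustration of what such a scan lowers to (a hedged sketch, not the exact IR the helper emits; %crdMem, %pos, and %posHi are stand-ins for the level's coordinate buffer and position bounds):

  %crd0 = memref.load %crdMem[%pos] : memref<?xindex>
  // %segHi = first position in [%pos, %posHi) whose coordinate differs
  // from the coordinate at %pos.
  %segHi = scf.while (%p = %pos) : (index) -> index {
    %inBound = arith.cmpi ult, %p, %posHi : index
    %cond = scf.if %inBound -> (i1) {
      %crd = memref.load %crdMem[%p] : memref<?xindex>
      %same = arith.cmpi eq, %crd, %crd0 : index
      scf.yield %same : i1
    } else {
      %f = arith.constant false
      scf.yield %f : i1
    }
    scf.condition(%cond) %p : index
  } do {
  ^bb0(%p: index):
    %next = arith.addi %p, %c1 : index
    scf.yield %next : index
  }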
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 17636af2b2f9d..91f363db93f1d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -132,6 +132,10 @@ class SparseIterationSpace {
   Value getBoundLo() const { return bound.first; }
   Value getBoundHi() const { return bound.second; }
 
+  // Extract an iterator to iterate over the sparse iteration space.
+  std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
+                                                  Location l) const;
+
 private:
   SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
   std::pair<Value, Value> bound;
@@ -192,6 +196,13 @@ class SparseIterator {
     crd = nullptr;
   }
 
+  // Reconstructs an iterator directly from the provided ValueRange.
+  static std::unique_ptr<SparseIterator>
+  fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
+
+  // The inverse operation of `fromValues`.
+  SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
+
   //
   // Iterator properties.
   //
@@ -345,12 +356,21 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                          unsigned tid,
                                                          Level lvl);
 
-/// Helper function to create a TensorLevel object from given `tensor`.
+/// Helper function to create a TensorLevel object from the given ValueRange.
 std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                          ValueRange buffers,
                                                          unsigned tid, Level l);
-/// Helper function to create a simple SparseIterator object that iterates
-/// over the SparseTensorLevel.
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the entire iteration space.
+std::unique_ptr<SparseIterator>
+makeSimpleIterator(OpBuilder &b, Location l,
+                   const SparseIterationSpace &iterSpace);
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the sparse tensor level.
+/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when
+/// feature complete.
 std::unique_ptr<SparseIterator> makeSimpleIterator(
     const SparseTensorLevel &stl,
     SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
index 5fcd661bb69b2..77a0e89dc7c81 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED
 
 #COO = #sparse_tensor.encoding<{
   map = (i, j) -> (
@@ -7,17 +8,44 @@
   )
 }>
 
-// CHECK-LABEL:   func.func @sparse_1D_space(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK:           %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
-// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK:           %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
-func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
-  return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+// CHECK-LABEL:   @sparse_iteration_to_scf
+//                  // deduplication
+// CHECK:           scf.while {{.*}} {
+// CHECK:           } do {
+// CHECK:           }
+// CHECK:           scf.while {{.*}} {
+// CHECK:           } do {
+//                    // actual computation
+// CHECK:             scf.for {{.*}} {
+// CHECK:               arith.addi
+// CHECK:             }
+//                    // deduplication
+// CHECK:             scf.while {{.*}} {
+// CHECK:             } do {
+// CHECK:             }
+// CHECK:             scf.yield
+// CHECK:           }
+// CHECK:           return
+
+// COLLAPSED-LABEL:   @sparse_iteration_to_scf
+// COLLAPSED:           %[[RET:.*]] = scf.for {{.*}} {
+// COLLAPSED:             %[[VAL:.*]] = arith.addi
+// COLLAPSED:             scf.yield %[[VAL]] : index
+// COLLAPSED:           }
+// COLLAPSED:           return %[[RET]] : index
+func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index {
+  %i = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
+    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+      %k = arith.addi %inner, %c1 : index
+      sparse_tensor.yield %k : index
+    }
+    sparse_tensor.yield %r2 : index
+  }
+  return %r1 : index
 }
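
To reproduce both CHECK streams locally, the RUN lines at the top of the test give the exact pipelines:

  mlir-opt sparse_iteration_to_scf.mlir --lower-sparse-iteration-to-scf
  mlir-opt sparse_iteration_to_scf.mlir --sparse-space-collapse --lower-sparse-iteration-to-scf

With --sparse-space-collapse, the two nested iterate ops over levels 0 and 1 are folded into a single loop first, which is why the COLLAPSED output is a lone scf.for rather than the nested dedup/while structure checked in the default stream.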

PeimingLiu merged commit d6cc35f into llvm:main on Jun 17, 2024
8 of 9 checks passed
Labels: mlir:sparse (Sparse compiler in MLIR), mlir