Skip to content

[mlir][sparse] move all COO related methods into SparseTensorType #73881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 30, 2023

Conversation

aartbik
Copy link
Contributor

@aartbik aartbik commented Nov 30, 2023

This centralizes all COO methods, and provides a cleaner API. Note that the "enc" only constructor is a temporary workaround the need for COO methods inside the "enc" only storage specifier.

This centralizes all COO methods, and provides a cleaner API.
Note that the "enc" only constructor is a temporary workaround
the need for COO methods inside the "enc" only storage specifier.
@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Nov 30, 2023
@llvmbot
Copy link
Member

llvmbot commented Nov 30, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

Changes

This centralizes all COO methods, and provides a cleaner API. Note that the "enc" only constructor is a temporary workaround the need for COO methods inside the "enc" only storage specifier.


Full diff: https://github.com/llvm/llvm-project/pull/73881.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (-13)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+18-2)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+35-44)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+3-4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+4-5)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+2-2)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 28dfdbdcf89b5bf..5e523ec428aefb9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,19 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
-/// Returns true iff the given sparse tensor encoding attribute has a trailing
-/// COO region starting at the given level.
-bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
-
-/// Returns true iff the given type is a COO type where the last level
-/// is unique.
-bool isUniqueCOOType(Type tp);
-
-/// Returns the starting level for a trailing COO region that spans
-/// at least two levels.  If no such COO region is found, then returns
-/// the level-rank.
-Level getCOOStart(SparseTensorEncodingAttr enc);
-
 /// Returns true iff MLIR operand has any sparse operand.
 inline bool hasAnySparseOperand(Operation *op) {
   return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index dc520e390de293d..4c98129744bcd94 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -60,6 +60,12 @@ class SparseTensorType {
       : SparseTensorType(
             RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
 
+  // TODO: remove?
+  SparseTensorType(SparseTensorEncodingAttr enc)
+      : SparseTensorType(RankedTensorType::get(
+            SmallVector<Size>(enc.getDimRank(), ShapedType::kDynamic),
+            Float32Type::get(enc.getContext()), enc)) {}
+
   SparseTensorType &operator=(const SparseTensorType &) = delete;
   SparseTensorType(const SparseTensorType &) = default;
 
@@ -234,9 +240,9 @@ class SparseTensorType {
                                        CrdTransDirectionKind::dim2lvl);
   }
 
+  /// Returns the type with an identity mapping.
   RankedTensorType getDemappedType() const {
-    auto lvlShape = getLvlShape();
-    return RankedTensorType::get(lvlShape, rtp.getElementType(),
+    return RankedTensorType::get(getLvlShape(), getElementType(),
                                  enc.withoutDimToLvl());
   }
 
@@ -311,6 +317,16 @@ class SparseTensorType {
     return IndexType::get(getContext());
   }
 
+  /// Returns true iff this sparse tensor type has a trailing
+  /// COO region starting at the given level. By default, it
+  /// tests for a unique COO type at top level.
+  bool isCOOType(Level startLvl = 0, bool isUnique = true) const;
+
+  /// Returns the starting level of this sparse tensor type for a
+  /// trailing COO region that spans **at least** two levels. If
+  /// no such COO region is found, then returns the level-rank.
+  Level getCOOStart() const;
+
   /// Returns [un]ordered COO type for this sparse tensor type.
   RankedTensorType getCOOType(bool ordered) const;
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index d4f8afdd62f2383..7dc4fc4f8570d60 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -66,7 +66,7 @@ void StorageLayout::foreachField(
         callback) const {
   const auto lvlTypes = enc.getLvlTypes();
   const Level lvlRank = enc.getLvlRank();
-  const Level cooStart = getCOOStart(enc);
+  const Level cooStart = SparseTensorType(enc).getCOOStart();
   const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
   FieldIndex fieldIdx = kDataFieldStartingIdx;
   // Per-level storage.
@@ -158,7 +158,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
   unsigned stride = 1;
   if (kind == SparseTensorFieldKind::CrdMemRef) {
     assert(lvl.has_value());
-    const Level cooStart = getCOOStart(enc);
+    const Level cooStart = SparseTensorType(enc).getCOOStart();
     const Level lvlRank = enc.getLvlRank();
     if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
       lvl = cooStart;
@@ -710,6 +710,28 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
 // SparseTensorType Methods.
 //===----------------------------------------------------------------------===//
 
+bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, bool isUnique) const {
+  if (!hasEncoding())
+    return false;
+  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
+    return false;
+  for (Level l = startLvl + 1; l < lvlRank; ++l)
+    if (!isSingletonLvl(l))
+      return false;
+  // If isUnique is true, then make sure that the last level is unique,
+  // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
+  // (unique on the last singleton).
+  return !isUnique || isUniqueLvl(lvlRank - 1);
+}
+
+Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
+  if (lvlRank > 1)
+    for (Level l = 0; l < lvlRank - 1; l++)
+      if (isCOOType(l, /*isUnique=*/false))
+        return l;
+  return lvlRank;
+}
+
 RankedTensorType
 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
   SmallVector<LevelType> lvlTypes;
@@ -859,25 +881,6 @@ bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
   return !coeffientMap.empty();
 }
 
-bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
-                                    Level startLvl, bool isUnique) {
-  if (!enc ||
-      !(enc.isCompressedLvl(startLvl) || enc.isLooseCompressedLvl(startLvl)))
-    return false;
-  const Level lvlRank = enc.getLvlRank();
-  for (Level l = startLvl + 1; l < lvlRank; ++l)
-    if (!enc.isSingletonLvl(l))
-      return false;
-  // If isUnique is true, then make sure that the last level is unique,
-  // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
-  // (unique on the last singleton).
-  return !isUnique || enc.isUniqueLvl(lvlRank - 1);
-}
-
-bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
-  return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
-}
-
 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
   auto hasNonIdentityMap = [](Value v) {
     auto stt = tryGetSparseTensorType(v);
@@ -888,17 +891,6 @@ bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
          llvm::any_of(op->getResults(), hasNonIdentityMap);
 }
 
-Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
-  // We only consider COO region with at least two levels for the purpose
-  // of AOS storage optimization.
-  const Level lvlRank = enc.getLvlRank();
-  if (lvlRank > 1)
-    for (Level l = 0; l < lvlRank - 1; l++)
-      if (isCOOType(enc, l, /*isUnique=*/false))
-        return l;
-  return lvlRank;
-}
-
 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
   if (enc) {
     assert(enc.isPermutation() && "Non permutation map not supported");
@@ -1013,7 +1005,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
     return op->emitError("the sparse-tensor must have the identity mapping");
 
   // Verifies the trailing COO.
-  Level cooStartLvl = getCOOStart(stt.getEncoding());
+  Level cooStartLvl = stt.getCOOStart();
   if (cooStartLvl < stt.getLvlRank()) {
     // We only supports trailing COO for now, must be the last input.
     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
@@ -1309,34 +1301,34 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult ToPositionsOp::verify() {
-  auto e = getSparseTensorEncoding(getTensor().getType());
+  auto stt = getSparseTensorType(getTensor());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
     return emitError("requested level is out of bounds");
-  if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
+  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
     return emitError("unexpected type for positions");
   return success();
 }
 
 LogicalResult ToCoordinatesOp::verify() {
-  auto e = getSparseTensorEncoding(getTensor().getType());
+  auto stt = getSparseTensorType(getTensor());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
     return emitError("requested level is out of bounds");
-  if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
+  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
     return emitError("unexpected type for coordinates");
   return success();
 }
 
 LogicalResult ToCoordinatesBufferOp::verify() {
-  auto e = getSparseTensorEncoding(getTensor().getType());
-  if (getCOOStart(e) >= e.getLvlRank())
+  auto stt = getSparseTensorType(getTensor());
+  if (stt.getCOOStart() >= stt.getLvlRank())
     return emitError("expected sparse tensor with a COO region");
   return success();
 }
 
 LogicalResult ToValuesOp::verify() {
-  auto ttp = getRankedTensorType(getTensor());
+  auto stt = getSparseTensorType(getTensor());
   auto mtp = getMemRefType(getResult());
-  if (ttp.getElementType() != mtp.getElementType())
+  if (stt.getElementType() != mtp.getElementType())
     return emitError("unexpected mismatch in element types");
   return success();
 }
@@ -1660,9 +1652,8 @@ LogicalResult ReorderCOOOp::verify() {
   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
 
-  if (!isCOOType(srcStt.getEncoding(), 0, /*isUnique=*/true) ||
-      !isCOOType(dstStt.getEncoding(), 0, /*isUnique=*/true))
-    emitError("Unexpected non-COO sparse tensors");
+  if (!srcStt.isCOOType() || !dstStt.isCOOType())
+    emitError("Expected COO sparse tensors only");
 
   if (!srcStt.hasSameDimToLvl(dstStt))
     emitError("Unmatched dim2lvl map between input and result COO");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index a245344755f0404..26f015ce6ec64f7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -412,8 +412,7 @@ void LoopEmitter::initializeLoopEmit(
     auto stt = getSparseTensorType(tensor);
     const Level lvlRank = stt.getLvlRank();
     const auto shape = rtp.getShape();
-    const auto enc = getSparseTensorEncoding(rtp);
-    const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
+    const Level cooStart = stt.getCOOStart();
 
     SmallVector<Value> lvlSzs;
     for (Level l = 0; l < stt.getLvlRank(); l++) {
@@ -457,8 +456,8 @@ void LoopEmitter::initializeLoopEmit(
     // values.
     // Delegates extra output initialization to clients.
     bool isOutput = isOutputTensor(t);
-    Type elementType = rtp.getElementType();
-    if (!enc) {
+    Type elementType = stt.getElementType();
+    if (!stt.hasEncoding()) {
       // Non-annotated dense tensors.
       BaseMemRefType denseTp = MemRefType::get(shape, elementType);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e9062b49435f5b7..18b2bb0819e2642 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -194,7 +194,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
       valHeuristic =
           builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
   } else if (sizeHint) {
-    if (getCOOStart(stt.getEncoding()) == 0) {
+    if (stt.getCOOStart() == 0) {
       posHeuristic = constantIndex(builder, loc, 2);
       crdHeuristic = builder.create<arith::MulIOp>(
           loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
@@ -657,8 +657,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
 
     // Should have been verified.
     assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
-           isUniqueCOOType(srcStt.getRankedTensorType()) &&
-           isUniqueCOOType(dstStt.getRankedTensorType()));
+           dstStt.isCOOType() && srcStt.isCOOType());
     assert(dstStt.hasSameDimToLvl(srcStt));
 
     // We don't need a mutable descriptor here as we perform sorting in-place.
@@ -1317,7 +1316,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
     Value posBack = c0; // index to the last value in the position array
     Value memSize = c1; // memory size for current array
 
-    Level trailCOOStart = getCOOStart(stt.getEncoding());
+    Level trailCOOStart = stt.getCOOStart();
     Level trailCOORank = stt.getLvlRank() - trailCOOStart;
     // Sets up SparseTensorSpecifier.
     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
@@ -1454,7 +1453,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
     const auto dstTp = getSparseTensorType(op.getResult());
     // Creating COO with NewOp is handled by direct IR codegen. All other cases
     // are handled by rewriting.
-    if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
+    if (!dstTp.hasEncoding() || dstTp.getCOOStart() != 0)
       return failure();
 
     // Implement as follows:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
index 1c6d7bebe37e46c..3ab4157475cd4c2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp
@@ -103,7 +103,7 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
 
 Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
     OpBuilder &builder, Location loc, Level lvl) const {
-  const Level cooStart = getCOOStart(rType.getEncoding());
+  const Level cooStart = rType.getCOOStart();
   if (lvl < cooStart)
     return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
index 4bd700eef522e04..5c7d8aa4c9d9678 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
@@ -137,7 +137,7 @@ class SparseTensorDescriptorImpl {
   }
 
   Value getAOSMemRef() const {
-    const Level cooStart = getCOOStart(rType.getEncoding());
+    const Level cooStart = rType.getCOOStart();
     assert(cooStart < rType.getLvlRank());
     return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2bd129b85ea5416..4fc692f2fe9ddc2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1180,8 +1180,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getResult());
-    auto enc = stt.getEncoding();
-    if (!stt.hasEncoding() || getCOOStart(enc) == 0)
+    if (!stt.hasEncoding() || stt.getCOOStart() == 0)
       return failure();
 
     // Implement the NewOp as follows:
@@ -1192,6 +1191,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
     RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
     Value convert = cooTensor;
+    auto enc = stt.getEncoding();
     if (!stt.isPermutation()) { // demap coo, demap dstTp
       auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
       convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);

Copy link

github-actions bot commented Nov 30, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@aartbik aartbik merged commit 5b72950 into llvm:main Nov 30, 2023
@aartbik aartbik deleted the bik branch November 30, 2023 17:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants