Skip to content

[mlir][sparse] move toCOOType into SparseTensorType class #73708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 29, 2023

Conversation

aartbik
Copy link
Contributor

@aartbik aartbik commented Nov 28, 2023

Migrates dangling convenience method into proper SparseTensorType class. Also cleans up some details (picking right dim2lvl/lvl2dim). Removes more dead code.

Migrates dangling convenience method into proper SparseTensorType
class. Also cleans up some details (picking right dim2lvl/lvl2dim).
Removes more dead code.
@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Nov 28, 2023
@llvmbot
Copy link
Member

llvmbot commented Nov 28, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Changes

Migrates dangling convenience method into proper SparseTensorType class. Also cleans up some details (picking right dim2lvl/lvl2dim). Removes more dead code.


Full diff: https://github.com/llvm/llvm-project/pull/73708.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (-4)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+9-20)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+27-2)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp (+2-5)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 517c286e0206997..28dfdbdcf89b5bf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -102,10 +102,6 @@ bool isUniqueCOOType(Type tp);
 /// the level-rank.
 Level getCOOStart(SparseTensorEncodingAttr enc);
 
-/// Helper to setup a COO type.
-RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src,
-                                            AffineMap ordering, bool ordered);
-
 /// Returns true iff MLIR operand has any sparse operand.
 inline bool hasAnySparseOperand(Operation *op) {
   return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4eb666d76cd2d6f..dc520e390de293d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -64,18 +64,14 @@ class SparseTensorType {
   SparseTensorType(const SparseTensorType &) = default;
 
   //
-  // Factory methods.
+  // Factory methods to construct a new `SparseTensorType`
+  // with the same dimension-shape and element type.
   //
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by the given encoding.
   SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
     return SparseTensorType(rtp, newEnc);
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withDimToLvl(dimToLvl)`.
   SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
     return withEncoding(enc.withDimToLvl(dimToLvl));
   }
@@ -88,23 +84,14 @@ class SparseTensorType {
     return withDimToLvl(dimToLvlSTT.getEncoding());
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withoutDimToLvl()`.
   SparseTensorType withoutDimToLvl() const {
     return withEncoding(enc.withoutDimToLvl());
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withBitWidths(posWidth, crdWidth)`.
   SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
     return withEncoding(enc.withBitWidths(posWidth, crdWidth));
   }
 
-  /// Constructs a new `SparseTensorType` with the same dimension-shape
-  /// and element type, but with the encoding replaced by
-  /// `getEncoding().withoutBitWidths()`.
   SparseTensorType withoutBitWidths() const {
     return withEncoding(enc.withoutBitWidths());
   }
@@ -118,10 +105,6 @@ class SparseTensorType {
     return withEncoding(enc.withoutDimSlices());
   }
 
-  //
-  // Other methods.
-  //
-
   /// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
   /// and `Type`.  These are implicit to help alleviate the impedance
   /// mismatch for code that has not been converted to use `SparseTensorType`
@@ -170,7 +153,6 @@ class SparseTensorType {
 
   Type getElementType() const { return rtp.getElementType(); }
 
-  /// Returns the encoding (or the null-attribute for dense-tensors).
   SparseTensorEncodingAttr getEncoding() const { return enc; }
 
   //
@@ -204,6 +186,10 @@ class SparseTensorType {
   /// (This is always true for dense-tensors.)
   bool isIdentity() const { return enc.isIdentity(); }
 
+  //
+  // Other methods.
+  //
+
   /// Returns the dimToLvl mapping (or the null-map for the identity).
   /// If you intend to compare the results of this method for equality,
   /// see `hasSameDimToLvl` instead.
@@ -325,6 +311,9 @@ class SparseTensorType {
     return IndexType::get(getContext());
   }
 
+  /// Returns [un]ordered COO type for this sparse tensor type.
+  RankedTensorType getCOOType(bool ordered) const;
+
 private:
   // These two must be const, to ensure coherence of the memoized fields.
   const RankedTensorType rtp;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ff2930008fa093f..edf7df3cfedabba 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -36,7 +36,7 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
-// Local convenience methods.
+// Local Convenience Methods.
 //===----------------------------------------------------------------------===//
 
 static constexpr bool acceptBitWidth(unsigned bitWidth) {
@@ -711,7 +711,32 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
 }
 
 //===----------------------------------------------------------------------===//
-// Convenience methods.
+// SparseTensorType SparseTensorType Methods.
+//===----------------------------------------------------------------------===//
+
+RankedTensorType
+mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
+  SmallVector<LevelType> lvlTypes;
+  lvlTypes.reserve(lvlRank);
+  // An unordered and non-unique compressed level at beginning.
+  // If this is also the last level, then it is unique.
+  lvlTypes.push_back(
+      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
+  if (lvlRank > 1) {
+    // Followed by unordered non-unique n-2 singleton levels.
+    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
+                *buildLevelType(LevelFormat::Singleton, ordered, false));
+    // Ends by a unique singleton level unless the lvlRank is 1.
+    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
+  }
+  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
+                                           getDimToLvl(), getLvlToDim(),
+                                           getPosWidth(), getCrdWidth());
+  return RankedTensorType::get(getDimShape(), getElementType(), enc);
+}
+
+//===----------------------------------------------------------------------===//
+// Convenience Methods.
 //===----------------------------------------------------------------------===//
 
 SparseTensorEncodingAttr
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index d8769eacc44f39b..c8e77f7de48300e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -25,9 +25,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
   Location loc = op.getLoc();
   Type finalTp = op->getOpResult(0).getType();
   SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
-
-  Type srcCOOTp = getCOOFromTypeWithOrdering(
-      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+  Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
 
   // Clones the original operation but changing the output to an unordered COO.
   Operation *cloned = rewriter.clone(*op.getOperation());
@@ -37,8 +35,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
   Value srcCOO = cloned->getOpResult(0);
 
   // -> sort
-  Type dstCOOTp = getCOOFromTypeWithOrdering(
-      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+  Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
   Value dstCOO = rewriter.create<ReorderCOOOp>(
       loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
 

@aartbik aartbik merged commit 4528808 into llvm:main Nov 29, 2023
@aartbik aartbik deleted the bik branch November 29, 2023 00:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants