Skip to content

[mlir][sparse] introduce sparse_tensor.coiterate operation. #101100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 31, 2024

Conversation

PeimingLiu
Copy link
Member

@PeimingLiu PeimingLiu commented Jul 29, 2024

This PR introduces the sparse_tensor.coiterate operation, which represents a loop that traverses multiple sparse iteration spaces.

@llvmbot
Copy link
Member

llvmbot commented Jul 29, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Patch is 28.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101100.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+40-15)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+6-5)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+120-3)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+228-65)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+34-1)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 68ca036121520..388efd1c454b1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -61,37 +61,62 @@ struct COOSegment {
 /// A simple wrapper to encode a bitset of (at most 64) levels, currently used
 /// by `sparse_tensor.iterate` operation for the set of levels on which the
 /// coordinates should be loaded.
-class LevelSet {
-  uint64_t bits = 0;
+class I64BitSet {
+  uint64_t storage = 0;
 
 public:
-  LevelSet() = default;
-  explicit LevelSet(uint64_t bits) : bits(bits) {}
-  operator uint64_t() const { return bits; }
+  using const_set_bits_iterator = llvm::const_set_bits_iterator_impl<I64BitSet>;
+  const_set_bits_iterator begin() const {
+    return const_set_bits_iterator(*this);
+  }
+  const_set_bits_iterator end() const {
+    return const_set_bits_iterator(*this, -1);
+  }
+  iterator_range<const_set_bits_iterator> bits() const {
+    return make_range(begin(), end());
+  }
+
+  I64BitSet() = default;
+  explicit I64BitSet(uint64_t bits) : storage(bits) {}
+  operator uint64_t() const { return storage; }
 
-  LevelSet &set(unsigned i) {
+  I64BitSet &set(unsigned i) {
     assert(i < 64);
-    bits |= static_cast<uint64_t>(0x01u) << i;
+    storage |= static_cast<uint64_t>(0x01u) << i;
     return *this;
   }
 
-  LevelSet &operator|=(LevelSet lhs) {
-    bits |= static_cast<uint64_t>(lhs);
+  I64BitSet &operator|=(I64BitSet lhs) {
+    storage |= static_cast<uint64_t>(lhs);
     return *this;
   }
 
-  LevelSet &lshift(unsigned offset) {
-    bits = bits << offset;
+  I64BitSet &lshift(unsigned offset) {
+    storage = storage << offset;
     return *this;
   }
 
+  // Needed by `llvm::const_set_bits_iterator_impl`.
+  int find_first() const { return min(); }
+  int find_next(unsigned prev) const {
+    if (prev >= max())
+      return -1;
+
+    uint64_t b = storage >> (prev + 1);
+    if (b == 0)
+      return -1;
+
+    return llvm::countr_zero(b) + prev + 1;
+  }
+
   bool operator[](unsigned i) const {
     assert(i < 64);
-    return (bits & (1 << i)) != 0;
+    return (storage & (1 << i)) != 0;
   }
-  unsigned max() const { return 64 - llvm::countl_zero(bits); }
-  unsigned count() const { return llvm::popcount(bits); }
-  bool empty() const { return bits == 0; }
+  unsigned min() const { return llvm::countr_zero(storage); }
+  unsigned max() const { return 64 - llvm::countl_zero(storage); }
+  unsigned count() const { return llvm::popcount(storage); }
+  bool empty() const { return storage == 0; }
 };
 
 } // namespace sparse_tensor
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 69b212cce4ceb..cb6c1b63e4e4b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -24,16 +24,17 @@ class SparseTensor_Attr<string name,
 // sparse tensor levels.
 //===----------------------------------------------------------------------===//
 
-def LevelSetAttr :
-    TypedAttrBase<
-      I64, "IntegerAttr",
+def I64BitSetAttr : TypedAttrBase<I64, "IntegerAttr",
       And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
            CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
       "LevelSet attribute"> {
-  let returnType = [{::mlir::sparse_tensor::LevelSet}];
-  let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+  let returnType = [{::mlir::sparse_tensor::I64BitSet}];
+  let convertFromStorage = [{::mlir::sparse_tensor::I64BitSet($_self.getValue().getZExtValue())}];
 }
 
+def I64BitSetArrayAttr :
+    TypedArrayAttrBase<I64BitSetAttr, "I64BitSet array attribute">;
+
 //===----------------------------------------------------------------------===//
 // These attributes are just like `IndexAttr` except that they clarify whether
 // the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index f31df080d7811..394934cfbd4ca 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1306,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
 
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp", "IterateOp"]>]> {
+                 "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1604,14 +1604,14 @@ def IterateOp : SparseTensor_Op<"iterate",
 
   let arguments = (ins AnySparseIterSpace:$iterSpace,
                        Variadic<AnyType>:$initArgs,
-                       LevelSetAttr:$crdUsedLvls);
+                       I64BitSetAttr:$crdUsedLvls);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
-    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "I64BitSet" :$crdUsedLvls)>
   ];
 
   let extraClassDeclaration = [{
@@ -1644,6 +1644,123 @@ def IterateOp : SparseTensor_Op<"iterate",
   let hasCustomAssemblyFormat = 1;
 }
 
+def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
+    [AttrSizedOperandSegments,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
+     RecursiveMemoryEffects]> {
+  let summary = "CoIterates over a set of sparse iteration spaces";
+  let description = [{
+      The `sparse_tensor.coiterate` operation represents a loop (nest) over
+      a set of iteration spaces.
+      The operation can have multiple regions, with each of them defining a
+      case to compute a result at the current iteration. The case condition
+      is defined solely based on the pattern of specified iterators.
+      For example:
+      ```mlir
+      %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+           : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+              !sparse_tensor.iter_space<#COO, lvls = 0>)
+           -> index
+      case %it1, _ {
+        // %coord is specified in space %sp1 but *NOT* specified in space %sp2.
+      }
+      case %it1, %it2 {
+        // %coord is specified in *BOTH* spaces %sp1 and %sp2.
+      }
+      ```
+
+      `sparse_tensor.coiterate` can also operate on loop-carried variables.
+      It returns the final values after loop termination.
+      The initial values of the variables are passed as additional SSA operands.
+      Each region of the operation has block arguments for the specified (used)
+      coordinates, followed by one argument for each loop-carried variable,
+      representing the value of the variable at the current iteration, and ends
+      with a list of arguments for the iterators defined in that case.
+      Each region must contain exactly one block that terminates with
+      `sparse_tensor.yield`.
+
+      The results of a `sparse_tensor.coiterate` hold the final values after
+      the last iteration. If the `sparse_tensor.coiterate` defines any values,
+      a yield must be explicitly present in every region defined in the operation.
+      The number and types of the `sparse_tensor.coiterate` results must match
+      the initial values in the iter_args binding and the yield operands.
+
+
+      A `sparse_tensor.coiterate` example that does elementwise addition between two
+      sparse vectors.
+
+
+      ```mlir
+      %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+           : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+              !sparse_tensor.iter_space<#CSR, lvls = 0>)
+           -> tensor<?xindex, #CSR>
+      case %it1, _ {
+         // v = v1 + 0 = v1
+         %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+         %yield = sparse_tensor.insert %v1 into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      case _, %it2 {
+         // v = v2 + 0 = v2
+         %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+         %yield = sparse_tensor.insert %v2 into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      case %it1, %it2 {
+         // v = v1 + v2
+         %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+         %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+         %v = arith.addi %v1, %v2 : index
+         %yield = sparse_tensor.insert %v into %arg[%coord]
+         sparse_tensor.yield %yield
+      }
+      ```
+  }];
+
+  let arguments = (ins Variadic<AnySparseIterSpace>:$iterSpaces,
+                       Variadic<AnyType>:$initArgs,
+                       I64BitSetAttr:$crdUsedLvls,
+                       I64BitSetArrayAttr:$cases);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
+                 getIterSpaces().front().getType())
+          .getSpaceDim();
+    }
+    I64BitSet getRegionDefinedSpace(unsigned regionIdx) {
+      return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
+                           .getValue().getZExtValue());
+    }
+    // The block arguments start with the referenced coordinates, followed by
+    // user-provided iteration arguments, and end with the iterators.
+    Block::BlockArgListType getCrds(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_front(getCrdUsedLvls().count());
+    }
+    unsigned getNumRegionIterArgs(unsigned regionIdx) {
+      return getInitArgs().size();
+    }
+    Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
+    }
+    Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_back(getRegionDefinedSpace(regionIdx).count());
+    }
+  }];
+
+  // TODO:
+  // let hasVerifier = 1;
+  // let hasRegionVerifier = 1;
+  // let hasCanonicalizer = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 616e91ae04055..33e3e839f8832 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2131,10 +2131,82 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
   printLevelRange(p, lo, hi);
 }
 
+/// Parses a list of optionally defined values in the form of
+/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
+/// corresponding value is not defined (e.g., to represent an undefined
+/// coordinate in the sparse iteration space).
+static ParseResult parseOptionalDefinedList(
+    OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
+    SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
+    unsigned maxCnt = std::numeric_limits<unsigned>::max(),
+    OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
+  unsigned cnt = 0;
+  ParseResult crdList =
+      parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
+        if (parser.parseOptionalKeyword("_")) {
+          if (parser.parseArgument(definedArgs.emplace_back()))
+            return failure();
+          definedSet.set(cnt);
+        }
+        cnt += 1;
+        return success();
+      });
+
+  if (cnt > maxCnt)
+    return parser.emitError(parser.getNameLoc(),
+                            "parsed more value than expected.");
+
+  if (failed(crdList)) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "expecting SSA value or \"_\" for level coordinates");
+  }
+  assert(definedArgs.size() == definedSet.count());
+  return success();
+}
+
+static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
+                                     Block::BlockArgListType blocksArgs,
+                                     I64BitSet definedSet) {
+  if (definedSet.empty())
+    return;
+
+  for (unsigned i = 0; i < size; i++) {
+    if (definedSet[i]) {
+      p << blocksArgs.front();
+      blocksArgs = blocksArgs.drop_front();
+    } else {
+      p << "_";
+    }
+    if (i != size - 1)
+      p << ", ";
+  }
+  assert(blocksArgs.empty());
+}
+
+static ParseResult
+parseUsedCoordList(OpAsmParser &parser, OperationState &state,
+                   SmallVectorImpl<OpAsmParser::Argument> &coords) {
+  // Parse "at(%crd0, _, ...)"
+  I64BitSet crdUsedLvlSet;
+  if (succeeded(parser.parseOptionalKeyword("at")) &&
+      failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
+    return failure();
+
+  // Always use IndexType for the coordinate.
+  for (auto &coord : coords)
+    coord.type = parser.getBuilder().getIndexType();
+
+  // Set the CrdUsedLvl bitset.
+  state.addAttribute("crdUsedLvls",
+                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+  return success();
+}
+
 static ParseResult
-parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
-                     SmallVectorImpl<OpAsmParser::Argument> &iterators,
-                     SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
+                       SmallVectorImpl<OpAsmParser::Argument> &iterators,
+                       SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
 
@@ -2148,37 +2220,14 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
         parser.getNameLoc(),
         "mismatch in number of sparse iterators and sparse spaces");
 
-  // Parse "at(%crd0, _, ...)"
-  LevelSet crdUsedLvlSet;
-  bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
-  unsigned lvlCrdCnt = 0;
-  if (hasUsedCrds) {
-    ParseResult crdList = parser.parseCommaSeparatedList(
-        OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
-          if (parser.parseOptionalKeyword("_")) {
-            if (parser.parseArgument(iterArgs.emplace_back()))
-              return failure();
-            // Always use IndexType for the coordinate.
-            crdUsedLvlSet.set(lvlCrdCnt);
-            iterArgs.back().type = parser.getBuilder().getIndexType();
-          }
-          lvlCrdCnt += 1;
-          return success();
-        });
-    if (failed(crdList)) {
-      return parser.emitError(
-          parser.getNameLoc(),
-          "expecting SSA value or \"_\" for level coordinates");
-    }
-  }
-  // Set the CrdUsedLvl bitset.
-  state.addAttribute("crdUsedLvls",
-                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+    return failure();
+  size_t numCrds = blockArgs.size();
 
   // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
   if (hasIterArgs)
-    if (parser.parseAssignmentList(iterArgs, initArgs))
+    if (parser.parseAssignmentList(blockArgs, initArgs))
       return failure();
 
   SmallVector<Type> iterSpaceTps;
@@ -2196,10 +2245,6 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
       return parser.emitError(parser.getNameLoc(),
                               "expected sparse_tensor.iter_space type for "
                               "iteration space operands");
-    if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
-      return parser.emitError(parser.getNameLoc(),
-                              "mismatch in number of iteration space dimension "
-                              "and specified coordinates");
     it.type = spaceTp.getIteratorType();
   }
 
@@ -2213,9 +2258,68 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
     return failure();
 
   if (hasIterArgs) {
-    unsigned numCrds = crdUsedLvlSet.count();
     // Strip off leading args that used for coordinates.
-    MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+    MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "mismatch in number of iteration arguments and return values");
+    }
+
+    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
+      it.type = tp;
+      if (parser.resolveOperand(init, tp, state.operands))
+        return failure();
+    }
+  }
+  return success();
+}
+
+static ParseResult
+parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
+                         SmallVectorImpl<Value> &spacesVals,
+                         SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
+
+  // Parse "(%spaces, ...)"
+  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+    return failure();
+  size_t numCrds = blockArgs.size();
+
+  // Parse "iter_args(%arg = %init, ...)"
+  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(blockArgs, initArgs))
+      return failure();
+
+  SmallVector<Type> iterSpaceTps;
+  // parse ": (sparse_tensor.iter_space, ...) -> ret"
+  if (parser.parseColon() || parser.parseLParen() ||
+      parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
+    return failure();
+
+  if (iterSpaceTps.size() != spaces.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of iteration space operands "
+                            "and iteration space types");
+
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(state.types))
+      return failure();
+
+  // Resolves input sparse iteration spaces.
+  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+                             spacesVals))
+    return failure();
+  state.operands.append(spacesVals);
+
+  if (hasIterArgs) {
+    // Strip off leading args that are used for coordinates.
+    MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
       return parser.emitError(
           parser.getNameLoc(),
@@ -2272,7 +2376,7 @@ struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
 
   LogicalResult matchAndRewrite(IterateOp iterateOp,
                                 PatternRewriter &rewriter) const override {
-    LevelSet newUsedLvls(0);
+    I64BitSet newUsedLvls(0);
     llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
     for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
       if (auto crd = iterateOp.getLvlCrd(i)) {
@@ -2304,13 +2408,13 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                       Value iterSpace, ValueRange initArgs) {
   unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
   // All ones.
-  LevelSet set((1 << rank) - 1);
+  I64BitSet set((1 << rank) - 1);
   return build(builder, odsState, iterSpace, initArgs, set);
 }
 
 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
      ...
[truncated]

Copy link
Contributor

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice progress!

It's still WIP, right? At least there's a TODO that looks like it'll be addressed as part of this PR, right?

Here and in previous PRs, a high-level summary in the PR description would have helped me to see the big picture faster.

@aartbik
Copy link
Contributor

aartbik commented Jul 31, 2024

Have a look at Ingo's feedback. Keeping the TODO is okay with me, since you plan to immediately follow up with a next PR on filling out those details.

@PeimingLiu PeimingLiu merged commit 785a24f into llvm:main Jul 31, 2024
4 of 6 checks passed
@PeimingLiu PeimingLiu deleted the sparse-coiterate branch July 31, 2024 22:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants