Skip to content

[mlir][sparse] allow YieldOp to yield multiple values. #87261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 1, 2024

Conversation

PeimingLiu
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Apr 1, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/87261.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+17-5)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+4-14)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+3-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-2)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c32447ecf..29e5ac749e9fa3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1278,8 +1278,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
   let hasVerifier = 1;
 }
 
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
-    Arguments<(ins Optional<AnyType>:$result)> {
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
+    ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
+                 "ForeachOp"]>]>,
+    Arguments<(ins Variadic<AnyType>:$results)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1302,14 +1304,24 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
   let builders = [
     OpBuilder<(ins),
     [{
-      build($_builder, $_state, Value());
+      build($_builder, $_state, ValueRange());
+    }]>,
+    OpBuilder<(ins "Value":$yieldVal),
+    [{
+      build($_builder, $_state, ValueRange(yieldVal));
     }]>
   ];
 
+  let extraClassDeclaration = [{
+     Value getSingleResult() {
+        assert(getResults().size() == 1);
+        return getResults().front();
+     }
+  }];
+
   let assemblyFormat = [{
-        $result attr-dict `:` type($result)
+        $results attr-dict `:` type($results)
   }];
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb6b9cacf..3f385d24daf23f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1591,7 +1591,8 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
   if (!yield)
     return op->emitError() << regionName
                            << " region must end with sparse_tensor.yield";
-  if (!yield.getResult() || yield.getResult().getType() != outputType)
+  if (!yield.getSingleResult() ||
+      yield.getSingleResult().getType() != outputType)
     return op->emitError() << regionName << " region yield type mismatch";
 
   return success();
@@ -1654,7 +1655,8 @@ LogicalResult UnaryOp::verify() {
     // Absent branch can only yield invariant values.
     Block *absentBlock = &absent.front();
     Block *parent = getOperation()->getBlock();
-    Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
+    Value absentVal =
+        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
       if (arg.getOwner() == parent)
         return emitError("absent region cannot yield linalg argument");
@@ -1907,18 +1909,6 @@ LogicalResult SortOp::verify() {
   return success();
 }
 
-LogicalResult YieldOp::verify() {
-  // Check for compatible parent.
-  auto *parentOp = (*this)->getParentOp();
-  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
-      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp))
-    return success();
-
-  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
-                     "reduce, select or foreach");
-}
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 14ea07f0b54b82..9c0fc60877d8a3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -764,9 +764,10 @@ struct ForeachOpDemapper
     if (numInitArgs != 0) {
       rewriter.setInsertionPointToEnd(body);
       auto yield = llvm::cast<YieldOp>(body->getTerminator());
-      if (auto stt = tryGetSparseTensorType(yield.getResult());
+      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
           stt && !stt->isIdentity()) {
-        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
+        Value y =
+            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
         rewriter.create<YieldOp>(loc, y);
         rewriter.eraseOp(yield);
       }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 72b722c69ae34b..9c0aed3c18eff2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1031,7 +1031,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
       // invariant on the right.
       Block &absentBlock = absentRegion.front();
       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
-      const Value absentVal = absentYield.getResult();
+      const Value absentVal = absentYield.getSingleResult();
       const ExprId rhs = addInvariantExp(absentVal);
       return disjSet(e, child0, buildLattices(rhs, i), unop);
     }
@@ -1500,7 +1500,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
   // Merge cloned block and return yield value.
   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
-  Value val = clonedYield.getResult();
+  Value val = clonedYield.getSingleResult();
   rewriter.eraseOp(clonedYield);
   rewriter.eraseOp(placeholder);
   return val;

@PeimingLiu PeimingLiu merged commit a54930e into llvm:main Apr 1, 2024
@PeimingLiu PeimingLiu deleted the multi-yields branch April 1, 2024 17:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants