Skip to content

[mlir][sparse] allow YieldOp to yield multiple values. #87261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1278,8 +1278,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
let hasVerifier = 1;
}

def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
Arguments<(ins Optional<AnyType>:$result)> {
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
"ForeachOp"]>]>,
Arguments<(ins Variadic<AnyType>:$results)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
Expand All @@ -1302,14 +1304,27 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
let builders = [
OpBuilder<(ins),
[{
build($_builder, $_state, Value());
build($_builder, $_state, ValueRange());
}]>,
OpBuilder<(ins "Value":$yieldVal),
[{
build($_builder, $_state, ValueRange(yieldVal));
}]>
];

let extraClassDeclaration = [{
Value getSingleResult() {
assert(hasSingleResult());
return getResults().front();
}
bool hasSingleResult() {
return getResults().size() == 1;
}
}];

let assemblyFormat = [{
$result attr-dict `:` type($result)
$results attr-dict `:` type($results)
}];
let hasVerifier = 1;
}

def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
Expand Down
18 changes: 4 additions & 14 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1591,7 +1591,8 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
if (!yield)
return op->emitError() << regionName
<< " region must end with sparse_tensor.yield";
if (!yield.getResult() || yield.getResult().getType() != outputType)
if (!yield.hasSingleResult() ||
yield.getSingleResult().getType() != outputType)
return op->emitError() << regionName << " region yield type mismatch";

return success();
Expand Down Expand Up @@ -1654,7 +1655,8 @@ LogicalResult UnaryOp::verify() {
// Absent branch can only yield invariant values.
Block *absentBlock = &absent.front();
Block *parent = getOperation()->getBlock();
Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
Value absentVal =
cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
if (arg.getOwner() == parent)
return emitError("absent region cannot yield linalg argument");
Expand Down Expand Up @@ -1907,18 +1909,6 @@ LogicalResult SortOp::verify() {
return success();
}

LogicalResult YieldOp::verify() {
// Check for compatible parent.
auto *parentOp = (*this)->getParentOp();
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
isa<ForeachOp>(parentOp))
return success();

return emitOpError("expected parent op to be sparse_tensor unary, binary, "
"reduce, select or foreach");
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -764,9 +764,10 @@ struct ForeachOpDemapper
if (numInitArgs != 0) {
rewriter.setInsertionPointToEnd(body);
auto yield = llvm::cast<YieldOp>(body->getTerminator());
if (auto stt = tryGetSparseTensorType(yield.getResult());
if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
stt && !stt->isIdentity()) {
Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
Value y =
genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
rewriter.create<YieldOp>(loc, y);
rewriter.eraseOp(yield);
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// invariant on the right.
Block &absentBlock = absentRegion.front();
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
const Value absentVal = absentYield.getResult();
const Value absentVal = absentYield.getSingleResult();
const ExprId rhs = addInvariantExp(absentVal);
return disjSet(e, child0, buildLattices(rhs, i), unop);
}
Expand Down Expand Up @@ -1500,7 +1500,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
// Merge cloned block and return yield value.
Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
Value val = clonedYield.getResult();
Value val = clonedYield.getSingleResult();
rewriter.eraseOp(clonedYield);
rewriter.eraseOp(placeholder);
return val;
Expand Down