Skip to content

Commit a54930e

Browse files
author
Peiming Liu
authored
[mlir][sparse] allow YieldOp to yield multiple values. (#87261)
1 parent 2cfd7d4 commit a54930e

File tree

4 files changed

+29
-23
lines changed

4 files changed

+29
-23
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,8 +1278,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
12781278
let hasVerifier = 1;
12791279
}
12801280

1281-
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
1282-
Arguments<(ins Optional<AnyType>:$result)> {
1281+
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
1282+
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1283+
"ForeachOp"]>]>,
1284+
Arguments<(ins Variadic<AnyType>:$results)> {
12831285
let summary = "Yield from sparse_tensor set-like operations";
12841286
let description = [{
12851287
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1302,14 +1304,27 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
13021304
let builders = [
13031305
OpBuilder<(ins),
13041306
[{
1305-
build($_builder, $_state, Value());
1307+
build($_builder, $_state, ValueRange());
1308+
}]>,
1309+
OpBuilder<(ins "Value":$yieldVal),
1310+
[{
1311+
build($_builder, $_state, ValueRange(yieldVal));
13061312
}]>
13071313
];
13081314

1315+
let extraClassDeclaration = [{
1316+
Value getSingleResult() {
1317+
assert(hasSingleResult());
1318+
return getResults().front();
1319+
}
1320+
bool hasSingleResult() {
1321+
return getResults().size() == 1;
1322+
}
1323+
}];
1324+
13091325
let assemblyFormat = [{
1310-
$result attr-dict `:` type($result)
1326+
$results attr-dict `:` type($results)
13111327
}];
1312-
let hasVerifier = 1;
13131328
}
13141329

13151330
def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,7 +1591,8 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
15911591
if (!yield)
15921592
return op->emitError() << regionName
15931593
<< " region must end with sparse_tensor.yield";
1594-
if (!yield.getResult() || yield.getResult().getType() != outputType)
1594+
if (!yield.hasSingleResult() ||
1595+
yield.getSingleResult().getType() != outputType)
15951596
return op->emitError() << regionName << " region yield type mismatch";
15961597

15971598
return success();
@@ -1654,7 +1655,8 @@ LogicalResult UnaryOp::verify() {
16541655
// Absent branch can only yield invariant values.
16551656
Block *absentBlock = &absent.front();
16561657
Block *parent = getOperation()->getBlock();
1657-
Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
1658+
Value absentVal =
1659+
cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
16581660
if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
16591661
if (arg.getOwner() == parent)
16601662
return emitError("absent region cannot yield linalg argument");
@@ -1907,18 +1909,6 @@ LogicalResult SortOp::verify() {
19071909
return success();
19081910
}
19091911

1910-
LogicalResult YieldOp::verify() {
1911-
// Check for compatible parent.
1912-
auto *parentOp = (*this)->getParentOp();
1913-
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
1914-
isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
1915-
isa<ForeachOp>(parentOp))
1916-
return success();
1917-
1918-
return emitOpError("expected parent op to be sparse_tensor unary, binary, "
1919-
"reduce, select or foreach");
1920-
}
1921-
19221912
/// Materialize a single constant operation from a given attribute value with
19231913
/// the desired resultant type.
19241914
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -764,9 +764,10 @@ struct ForeachOpDemapper
764764
if (numInitArgs != 0) {
765765
rewriter.setInsertionPointToEnd(body);
766766
auto yield = llvm::cast<YieldOp>(body->getTerminator());
767-
if (auto stt = tryGetSparseTensorType(yield.getResult());
767+
if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
768768
stt && !stt->isIdentity()) {
769-
Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
769+
Value y =
770+
genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
770771
rewriter.create<YieldOp>(loc, y);
771772
rewriter.eraseOp(yield);
772773
}

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,7 +1031,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
10311031
// invariant on the right.
10321032
Block &absentBlock = absentRegion.front();
10331033
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1034-
const Value absentVal = absentYield.getResult();
1034+
const Value absentVal = absentYield.getSingleResult();
10351035
const ExprId rhs = addInvariantExp(absentVal);
10361036
return disjSet(e, child0, buildLattices(rhs, i), unop);
10371037
}
@@ -1500,7 +1500,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
15001500
// Merge cloned block and return yield value.
15011501
Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
15021502
rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1503-
Value val = clonedYield.getResult();
1503+
Value val = clonedYield.getSingleResult();
15041504
rewriter.eraseOp(clonedYield);
15051505
rewriter.eraseOp(placeholder);
15061506
return val;

0 commit comments

Comments
 (0)