Skip to content

Commit ad89eb5

Browse files
committed
Revert "Revert "[mlir][linalg] Add nicer builders for map and reduce.""
This reverts commit 7eef3ea.
1 parent acdd576 commit ad89eb5

File tree

2 files changed

+98
-53
lines changed

2 files changed

+98
-53
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [
267267
let results = (outs Variadic<AnyTensor>:$result);
268268
let regions = (region SizedRegion<1>:$mapper);
269269

270+
let builders = [
271+
OpBuilder<(ins "ValueRange":$inputs, "Value":$init,
272+
"function_ref<void(OpBuilder &, Location, ValueRange)>",
273+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
274+
];
275+
270276
let extraClassDeclaration = structuredOpsBaseDecls # [{
271277
// Implement functions necessary for LinalgStructuredInterface.
272278
SmallVector<StringRef> getIteratorTypesArray();
@@ -341,6 +347,13 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
341347
let results = (outs Variadic<AnyTensor>);
342348
let regions = (region SizedRegion<1>:$combiner);
343349

350+
let builders = [
351+
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
352+
"ArrayRef<int64_t>":$dimensions,
353+
"function_ref<void(OpBuilder &, Location, ValueRange)>",
354+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
355+
];
356+
344357
let extraClassDeclaration = structuredOpsBaseDecls # [{
345358
// Declare functions necessary for LinalgStructuredInterface.
346359
SmallVector<StringRef> getIteratorTypesArray();

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 85 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,26 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
661661
// GenericOp
662662
//===----------------------------------------------------------------------===//
663663

664+
static void buildGenericRegion(
665+
OpBuilder &builder, OperationState &result, ValueRange inputs,
666+
ValueRange outputs,
667+
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
668+
SmallVector<Type, 4> blockArgTypes;
669+
SmallVector<Location, 4> blockArgLocs;
670+
for (ValueRange container : {inputs, outputs}) {
671+
for (Value v : container) {
672+
blockArgTypes.push_back(getElementTypeOrSelf(v));
673+
blockArgLocs.push_back(v.getLoc());
674+
}
675+
}
676+
677+
OpBuilder::InsertionGuard guard(builder);
678+
auto &region = *result.regions.front();
679+
Block *bodyBlock =
680+
builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
681+
bodyBuild(builder, result.location, bodyBlock->getArguments());
682+
}
683+
664684
void GenericOp::getAsmBlockArgumentNames(Region &region,
665685
OpAsmSetValueNameFn setNameFn) {
666686
for (Value v : getRegionInputArgs())
@@ -678,23 +698,8 @@ void GenericOp::build(
678698
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
679699
iteratorTypes, doc, libraryCall);
680700
result.addAttributes(attributes);
681-
if (!bodyBuild)
682-
return;
683-
684-
SmallVector<Type, 4> blockArgTypes;
685-
SmallVector<Location, 4> blockArgLocs;
686-
for (ValueRange container : {inputs, outputs}) {
687-
for (Value v : container) {
688-
blockArgTypes.push_back(getElementTypeOrSelf(v));
689-
blockArgLocs.push_back(v.getLoc());
690-
}
691-
}
692-
693-
OpBuilder::InsertionGuard guard(builder);
694-
auto &region = *result.regions.front();
695-
Block *bodyBlock =
696-
builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
697-
bodyBuild(builder, result.location, bodyBlock->getArguments());
701+
if (bodyBuild)
702+
buildGenericRegion(builder, result, inputs, outputs, bodyBuild);
698703
}
699704

700705
void GenericOp::build(
@@ -1329,6 +1334,22 @@ void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
13291334
setNameFn(getResults().front(), "mapped");
13301335
}
13311336

1337+
void MapOp::build(
1338+
OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1339+
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1340+
ArrayRef<NamedAttribute> attributes) {
1341+
build(builder, result, TypeRange{}, inputs, init);
1342+
result.addAttributes(attributes);
1343+
1344+
// Add output types for `RankedTensorType` output arguments.
1345+
Type initType = init.getType();
1346+
if (initType.isa<RankedTensorType>())
1347+
result.addTypes(initType);
1348+
1349+
if (bodyBuild)
1350+
buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild);
1351+
}
1352+
13321353
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
13331354
if (parseDstStyleOp(parser, result))
13341355
return failure();
@@ -1436,6 +1457,25 @@ void ReduceOp::getAsmResultNames(
14361457
setNameFn(getResults().front(), "reduced");
14371458
}
14381459

1460+
void ReduceOp::build(
1461+
OpBuilder &builder, OperationState &result, ValueRange inputs,
1462+
ValueRange inits, ArrayRef<int64_t> dimensions,
1463+
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1464+
ArrayRef<NamedAttribute> attributes) {
1465+
build(builder, result, TypeRange{}, inputs, inits, dimensions);
1466+
result.addAttributes(attributes);
1467+
1468+
// Add output types for `RankedTensorType` output arguments.
1469+
for (Value init : inits) {
1470+
Type initType = init.getType();
1471+
if (initType.isa<RankedTensorType>())
1472+
result.addTypes(initType);
1473+
}
1474+
1475+
if (bodyBuild)
1476+
buildGenericRegion(builder, result, inputs, inits, bodyBuild);
1477+
}
1478+
14391479
SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
14401480
int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
14411481
SmallVector<StringRef> iteratorTypes(inputRank,
@@ -1618,45 +1658,32 @@ TransposeOp::getRegionBuilder() {
16181658
};
16191659
}
16201660

1621-
void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
1622-
::mlir::OperationState &odsState) {
1623-
Region *region = odsState.addRegion();
1624-
1625-
SmallVector<Type> argTypes;
1626-
SmallVector<Location> argLocs;
1627-
for (auto t : odsState.operands) {
1628-
argTypes.push_back(getElementTypeOrSelf(t));
1629-
argLocs.push_back(opBuilder.getUnknownLoc());
1630-
}
1631-
1632-
// RAII.
1633-
OpBuilder::InsertionGuard guard(opBuilder);
1634-
Block *body =
1635-
opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
1636-
1637-
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
1638-
getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
1639-
}
1640-
1641-
void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
1642-
::mlir::OperationState &odsState, Value input,
1643-
Value init, DenseI64ArrayAttr permutation,
1661+
void TransposeOp::build(::mlir::OpBuilder &builder,
1662+
::mlir::OperationState &result, Value input, Value init,
1663+
DenseI64ArrayAttr permutation,
16441664
ArrayRef<NamedAttribute> attributes) {
1645-
odsState.addOperands(input);
1646-
odsState.addOperands(init);
1647-
odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
1648-
odsState.addAttributes(attributes);
1649-
odsState.addTypes(init.getType());
1665+
result.addOperands(input);
1666+
result.addOperands(init);
1667+
result.addAttribute(getPermutationAttrName(result.name), permutation);
1668+
result.addAttributes(attributes);
1669+
1670+
// Add output types for `RankedTensorType` output arguments.
1671+
Type initType = init.getType();
1672+
if (initType.isa<RankedTensorType>())
1673+
result.addTypes(initType);
16501674

1651-
createRegion(odsBuilder, odsState);
1675+
buildGenericRegion(builder, result, input, init,
1676+
[&](OpBuilder &b, Location loc, ValueRange args) {
1677+
b.create<linalg::YieldOp>(loc, args[0]);
1678+
});
16521679
}
16531680

1654-
void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
1655-
::mlir::OperationState &odsState, Value input,
1656-
Value init, ArrayRef<int64_t> permutation,
1681+
void TransposeOp::build(::mlir::OpBuilder &builder,
1682+
::mlir::OperationState &result, Value input, Value init,
1683+
ArrayRef<int64_t> permutation,
16571684
ArrayRef<NamedAttribute> attributes) {
1658-
build(odsBuilder, odsState, input, init,
1659-
odsBuilder.getDenseI64ArrayAttr(permutation), attributes);
1685+
build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1686+
attributes);
16601687
}
16611688

16621689
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1666,8 +1693,13 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
16661693
})))
16671694
return failure();
16681695

1669-
OpBuilder opBuilder(parser.getContext());
1670-
createRegion(opBuilder, result);
1696+
(void)result.addRegion();
1697+
OpBuilder builder(parser.getContext());
1698+
buildGenericRegion(builder, result, /*inputs=*/result.operands,
1699+
/*outputs=*/{},
1700+
[&](OpBuilder &b, Location loc, ValueRange args) {
1701+
b.create<linalg::YieldOp>(loc, args[0]);
1702+
});
16711703
return success();
16721704
}
16731705

0 commit comments

Comments
 (0)