Skip to content

Commit 7f9e9c7

Browse files
committed
Move getAsmBlockArgumentNames from OpAsmDialectInterface to OpAsmOpInterface
This method is more suitable as an opinterface: it seems intrinsic to individual instances of the operation instead of the dialect. Also remove the restriction on the interface being applicable to the entry block only. Differential Revision: https://reviews.llvm.org/D116018
1 parent 9c11e95 commit 7f9e9c7

File tree

5 files changed

+45
-37
lines changed

5 files changed

+45
-37
lines changed

mlir/include/mlir/IR/OpAsmInterface.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,18 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
4949
}],
5050
"void", "getAsmResultNames",
5151
(ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
52-
"", ";"
52+
"", "return;"
53+
>,
54+
InterfaceMethod<[{
55+
Get a special name to use when printing the block arguments for a region
56+
immediately nested under this operation.
57+
}],
58+
"void", "getAsmBlockArgumentNames",
59+
(ins
60+
"::mlir::Region&":$region,
61+
"::mlir::OpAsmSetValueNameFn":$setNameFn
62+
),
63+
"", "return;"
5364
>,
5465
StaticInterfaceMethod<[{
5566
Return the default dialect used when printing/parsing operations in

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,11 +1348,6 @@ class OpAsmDialectInterface
13481348
/// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
13491349
virtual void getAsmResultNames(Operation *op,
13501350
OpAsmSetValueNameFn setNameFn) const {}
1351-
1352-
/// Get a special name to use when printing the entry block arguments of the
1353-
/// region contained by an operation in this dialect.
1354-
virtual void getAsmBlockArgumentNames(Block *block,
1355-
OpAsmSetValueNameFn setNameFn) const {}
13561351
};
13571352
} // namespace mlir
13581353

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,20 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
10061006
}
10071007

10081008
void SSANameState::numberValuesInRegion(Region &region) {
1009+
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1010+
assert(!valueIDs.count(arg) && "arg numbered multiple times");
1011+
assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
1012+
"arg not defined in current region");
1013+
setValueName(arg, name);
1014+
};
1015+
1016+
if (!printerFlags.shouldPrintGenericOpForm()) {
1017+
if (Operation *op = region.getParentOp()) {
1018+
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1019+
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1020+
}
1021+
}
1022+
10091023
// Number the values within this region in a breadth-first order.
10101024
unsigned nextBlockID = 0;
10111025
for (auto &block : region) {
@@ -1017,23 +1031,9 @@ void SSANameState::numberValuesInRegion(Region &region) {
10171031
}
10181032

10191033
void SSANameState::numberValuesInBlock(Block &block) {
1020-
auto setArgNameFn = [&](Value arg, StringRef name) {
1021-
assert(!valueIDs.count(arg) && "arg numbered multiple times");
1022-
assert(arg.cast<BlockArgument>().getOwner() == &block &&
1023-
"arg not defined in 'block'");
1024-
setValueName(arg, name);
1025-
};
1026-
1027-
bool isEntryBlock = block.isEntryBlock();
1028-
if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) {
1029-
if (auto *op = block.getParentOp()) {
1030-
if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
1031-
asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
1032-
}
1033-
}
1034-
10351034
// Number the block arguments. We give entry block arguments a special name
10361035
// 'arg'.
1036+
bool isEntryBlock = block.isEntryBlock();
10371037
SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
10381038
llvm::raw_svector_ostream specialName(specialNameBuffer);
10391039
for (auto arg : block.getArguments()) {

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,6 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
105105
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
106106
setNameFn(asmOp, "result");
107107
}
108-
109-
void getAsmBlockArgumentNames(Block *block,
110-
OpAsmSetValueNameFn setNameFn) const final {
111-
auto op = block->getParentOp();
112-
auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
113-
if (!arrayAttr)
114-
return;
115-
auto args = block->getArguments();
116-
auto e = std::min(arrayAttr.size(), args.size());
117-
for (unsigned i = 0; i < e; ++i) {
118-
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
119-
setNameFn(args[i], strAttr.getValue());
120-
}
121-
}
122108
};
123109

124110
struct TestDialectFoldInterface : public DialectFoldInterface {
@@ -848,6 +834,19 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
848834
return parser.parseRegion(*body, ivsInfo, argTypes);
849835
}
850836

837+
void PolyForOp::getAsmBlockArgumentNames(Region &region,
838+
OpAsmSetValueNameFn setNameFn) {
839+
auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
840+
if (!arrayAttr)
841+
return;
842+
auto args = getRegion().front().getArguments();
843+
auto e = std::min(arrayAttr.size(), args.size());
844+
for (unsigned i = 0; i < e; ++i) {
845+
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
846+
setNameFn(args[i], strAttr.getValue());
847+
}
848+
}
849+
851850
//===----------------------------------------------------------------------===//
852851
// Test removing op with inner ops.
853852
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,13 +1667,16 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
16671667
let printer = [{ return ::print(p, *this); }];
16681668
}
16691669

1670-
def PolyForOp : TEST_Op<"polyfor">
1670+
def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
16711671
{
16721672
let summary = "polyfor operation";
16731673
let description = [{
16741674
Test op with multiple region arguments, each argument of index type.
16751675
}];
1676-
1676+
let extraClassDeclaration = [{
1677+
void getAsmBlockArgumentNames(mlir::Region &region,
1678+
mlir::OpAsmSetValueNameFn setNameFn);
1679+
}];
16771680
let regions = (region SizedRegion<1>:$region);
16781681
let parser = [{ return ::parse$cppClass(parser, result); }];
16791682
}

0 commit comments

Comments
 (0)