Skip to content

Commit b055e6d

Browse files
committed
Add a new interface method getAsmBlockName() on OpAsmOpInterface to control block names
This allows operations to control the block ids used by the printer in nested regions. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D115849
1 parent 3571bdb commit b055e6d

File tree

5 files changed

+136
-27
lines changed

5 files changed

+136
-27
lines changed

mlir/include/mlir/IR/OpAsmInterface.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,36 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
6262
),
6363
"", "return;"
6464
>,
65+
InterfaceMethod<[{
66+
Get the name to use for a given block inside a region attached to this
67+
operation.
68+
69+
For example if this operation has multiple blocks:
70+
71+
```mlir
72+
some.op() ({
73+
^bb0:
74+
...
75+
^bb1:
76+
...
77+
})
78+
```
79+
80+
the method will be invoked on each of the blocks allowing the op to
81+
print:
82+
83+
```mlir
84+
some.op() ({
85+
^custom_foo_name:
86+
...
87+
^custom_bar_name:
88+
...
89+
})
90+
```
91+
}],
92+
"void", "getAsmBlockNames",
93+
(ins "::mlir::OpAsmSetBlockNameFn":$setNameFn), "", ";"
94+
>,
6595
StaticInterfaceMethod<[{
6696
Return the default dialect used when printing/parsing operations in
6797
regions nested under this operation. This allows for eliding the dialect

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,10 @@ class OpAsmParser : public AsmParser {
13221322
/// operation. See 'getAsmResultNames' below for more details.
13231323
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
13241324

1325+
/// A functor used to set the name of blocks in regions directly nested under
1326+
/// an operation.
1327+
using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1328+
13251329
class OpAsmDialectInterface
13261330
: public DialectInterface::Base<OpAsmDialectInterface> {
13271331
public:

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,13 @@ void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
791791
//===----------------------------------------------------------------------===//
792792

793793
namespace {
794+
/// Info about block printing: a number which is its position in the visitation
795+
/// order, and a name that is used to print reference to it, e.g. ^bb42.
796+
struct BlockInfo {
797+
int ordering;
798+
StringRef name;
799+
};
800+
794801
/// This class manages the state of SSA value names.
795802
class SSANameState {
796803
public:
@@ -808,8 +815,8 @@ class SSANameState {
808815
/// operation, or empty if none exist.
809816
ArrayRef<int> getOpResultGroups(Operation *op);
810817

811-
/// Get the ID for the given block.
812-
unsigned getBlockID(Block *block);
818+
/// Get the info for the given block.
819+
BlockInfo getBlockInfo(Block *block);
813820

814821
/// Renumber the arguments for the specified region to the same names as the
815822
/// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
@@ -846,8 +853,9 @@ class SSANameState {
846853
/// value of this map are the result numbers that start a result group.
847854
DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
848855

849-
/// This is the block ID for each block in the current.
850-
DenseMap<Block *, unsigned> blockIDs;
856+
/// This maps blocks to there visitation number in the current region as well
857+
/// as the string representing their name.
858+
DenseMap<Block *, BlockInfo> blockNames;
851859

852860
/// This keeps track of all of the non-numeric names that are in flight,
853861
/// allowing us to check for duplicates.
@@ -967,9 +975,10 @@ ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
967975
return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
968976
}
969977

970-
unsigned SSANameState::getBlockID(Block *block) {
971-
auto it = blockIDs.find(block);
972-
return it != blockIDs.end() ? it->second : NameSentinel;
978+
BlockInfo SSANameState::getBlockInfo(Block *block) {
979+
auto it = blockNames.find(block);
980+
BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
981+
return it != blockNames.end() ? it->second : invalidBlock;
973982
}
974983

975984
void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
@@ -1021,7 +1030,16 @@ void SSANameState::numberValuesInRegion(Region &region) {
10211030
for (auto &block : region) {
10221031
// Each block gets a unique ID, and all of the operations within it get
10231032
// numbered as well.
1024-
blockIDs[&block] = nextBlockID++;
1033+
auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
1034+
if (blockInfoIt.second) {
1035+
// This block hasn't been named through `getAsmBlockArgumentNames`, use
1036+
// default `^bbNNN` format.
1037+
std::string name;
1038+
llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1039+
blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
1040+
}
1041+
blockInfoIt.first->second.ordering = nextBlockID++;
1042+
10251043
numberValuesInBlock(block);
10261044
}
10271045
}
@@ -1048,11 +1066,6 @@ void SSANameState::numberValuesInBlock(Block &block) {
10481066
}
10491067

10501068
void SSANameState::numberValuesInOp(Operation &op) {
1051-
unsigned numResults = op.getNumResults();
1052-
if (numResults == 0)
1053-
return;
1054-
Value resultBegin = op.getResult(0);
1055-
10561069
// Function used to set the special result names for the operation.
10571070
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
10581071
auto setResultNameFn = [&](Value result, StringRef name) {
@@ -1064,11 +1077,34 @@ void SSANameState::numberValuesInOp(Operation &op) {
10641077
if (int resultNo = result.cast<OpResult>().getResultNumber())
10651078
resultGroups.push_back(resultNo);
10661079
};
1080+
// Operations can customize the printing of block names in OpAsmOpInterface.
1081+
auto setBlockNameFn = [&](Block *block, StringRef name) {
1082+
assert(block->getParentOp() == &op &&
1083+
"getAsmBlockArgumentNames callback invoked on a block not directly "
1084+
"nested under the current operation");
1085+
assert(!blockNames.count(block) && "block numbered multiple times");
1086+
SmallString<16> tmpBuffer{"^"};
1087+
name = sanitizeIdentifier(name, tmpBuffer);
1088+
if (name.data() != tmpBuffer.data()) {
1089+
tmpBuffer.append(name);
1090+
name = tmpBuffer.str();
1091+
}
1092+
name = name.copy(usedNameAllocator);
1093+
blockNames[block] = {-1, name};
1094+
};
1095+
10671096
if (!printerFlags.shouldPrintGenericOpForm()) {
1068-
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
1097+
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
1098+
asmInterface.getAsmBlockNames(setBlockNameFn);
10691099
asmInterface.getAsmResultNames(setResultNameFn);
1100+
}
10701101
}
10711102

1103+
unsigned numResults = op.getNumResults();
1104+
if (numResults == 0)
1105+
return;
1106+
Value resultBegin = op.getResult(0);
1107+
10721108
// If the first result wasn't numbered, give it a default number.
10731109
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
10741110
++nextValueID;
@@ -2609,11 +2645,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
26092645
}
26102646

26112647
void OperationPrinter::printBlockName(Block *block) {
2612-
auto id = state->getSSANameState().getBlockID(block);
2613-
if (id != SSANameState::NameSentinel)
2614-
os << "^bb" << id;
2615-
else
2616-
os << "^INVALIDBLOCK";
2648+
os << state->getSSANameState().getBlockInfo(block).name;
26172649
}
26182650

26192651
void OperationPrinter::print(Block *block, bool printBlockArgs,
@@ -2647,18 +2679,18 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
26472679
os << " // pred: ";
26482680
printBlockName(pred);
26492681
} else {
2650-
// We want to print the predecessors in increasing numeric order, not in
2682+
// We want to print the predecessors in a stable order, not in
26512683
// whatever order the use-list is in, so gather and sort them.
2652-
SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
2684+
SmallVector<BlockInfo, 4> predIDs;
26532685
for (auto *pred : block->getPredecessors())
2654-
predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
2655-
llvm::array_pod_sort(predIDs.begin(), predIDs.end());
2686+
predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
2687+
llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
2688+
return lhs.ordering < rhs.ordering;
2689+
});
26562690

26572691
os << " // " << predIDs.size() << " preds: ";
26582692

2659-
interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
2660-
printBlockName(pred.second);
2661-
});
2693+
interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
26622694
}
26632695
os << newLine;
26642696
}

mlir/test/IR/pretty_printed_region_op.mlir

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) {
3636

3737
// -----
3838

39-
4039
func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
4140
// CHECK-LOCATION: "test.pretty_printed_region"(%arg1, %arg0)
4241
// CHECK-LOCATION: ^bb0(%arg[[x:[0-9]+]]: f32 loc("foo"), %arg[[y:[0-9]+]]: f32 loc("foo")
@@ -47,3 +46,29 @@ func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
4746
%res = test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> (f32) loc("foo")
4847
return %res : f32
4948
}
49+
50+
// -----
51+
52+
// This tests the behavior of custom block names:
53+
// operations like `test.block_names` can define custom names for blocks in
54+
// nested regions.
55+
// CHECK-CUSTOM-LABEL: func @block_names
56+
func @block_names(%bool : i1) {
57+
// CHECK: test.block_names
58+
test.block_names {
59+
// CHECK-CUSTOM: br ^foo1
60+
// CHECK-GENERIC: cf.br{{.*}}^bb1
61+
cf.br ^foo1
62+
// CHECK-CUSTOM: ^foo1:
63+
// CHECK-GENERIC: ^bb1:
64+
^foo1:
65+
// CHECK-CUSTOM: br ^foo2
66+
// CHECK-GENERIC: cf.br{{.*}}^bb2
67+
cf.br ^foo2
68+
// CHECK-CUSTOM: ^foo2:
69+
// CHECK-GENERIC: ^bb2:
70+
^foo2:
71+
"test.return"() : () -> ()
72+
}
73+
return
74+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,24 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
660660
let assemblyFormat = "regions attr-dict-with-keyword";
661661
}
662662

663+
// This is used to test the OpAsmOpInterface::getAsmBlockName() feature:
664+
// blocks nested in a region under this op will have a name defined by the
665+
// interface.
666+
def AsmBlockNameOp : TEST_Op<"block_names", [OpAsmOpInterface]> {
667+
let regions = (region AnyRegion:$body);
668+
let extraClassDeclaration = [{
669+
void getAsmBlockNames(mlir::OpAsmSetBlockNameFn setNameFn) {
670+
std::string name;
671+
int count = 0;
672+
for (::mlir::Block &block : getRegion().getBlocks()) {
673+
name = "foo" + std::to_string(count++);
674+
setNameFn(&block, name);
675+
}
676+
}
677+
}];
678+
let assemblyFormat = "regions attr-dict-with-keyword";
679+
}
680+
663681
// This operation requires its return type to have the trait 'TestTypeTrait'.
664682
def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
665683
let results = (outs AnyType);

0 commit comments

Comments
 (0)