Skip to content

Commit 57fdcb1

Browse files
[mlir] Add getAlias for OpAsmTypeInterface
1 parent 95922d8 commit 57fdcb1

File tree

5 files changed

+50
-10
lines changed

5 files changed

+50
-10
lines changed

mlir/include/mlir/IR/OpAsmInterface.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
127127
"void", "getAsmName",
128128
(ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
129129
>,
130+
InterfaceMethod<[{
131+
Get a name to use when generating an alias for this type.
132+
}],
133+
"::mlir::OpAsmDialectInterface::AliasResult", "getAlias",
134+
(ins "::llvm::raw_ostream&":$os), "",
135+
"return ::mlir::OpAsmDialectInterface::AliasResult::NoAlias;"
136+
>,
130137
];
131138
}
132139

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,15 +1159,31 @@ template <typename T>
11591159
void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
11601160
bool canBeDeferred) {
11611161
SmallString<32> nameBuffer;
1162-
for (const auto &interface : interfaces) {
1163-
OpAsmDialectInterface::AliasResult result =
1164-
interface.getAlias(symbol, aliasOS);
1165-
if (result == OpAsmDialectInterface::AliasResult::NoAlias)
1166-
continue;
1167-
nameBuffer = std::move(aliasBuffer);
1168-
assert(!nameBuffer.empty() && "expected valid alias name");
1169-
if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
1170-
break;
1162+
1163+
OpAsmDialectInterface::AliasResult symbolInterfaceResult =
1164+
OpAsmDialectInterface::AliasResult::NoAlias;
1165+
if constexpr (std::is_base_of_v<Type, T>) {
1166+
if (auto symbolInterface = mlir::dyn_cast<OpAsmTypeInterface>(symbol)) {
1167+
symbolInterfaceResult = symbolInterface.getAlias(aliasOS);
1168+
if (symbolInterfaceResult !=
1169+
OpAsmDialectInterface::AliasResult::NoAlias) {
1170+
nameBuffer = std::move(aliasBuffer);
1171+
assert(!nameBuffer.empty() && "expected valid alias name");
1172+
}
1173+
}
1174+
}
1175+
1176+
if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::FinalAlias) {
1177+
for (const auto &interface : interfaces) {
1178+
OpAsmDialectInterface::AliasResult result =
1179+
interface.getAlias(symbol, aliasOS);
1180+
if (result == OpAsmDialectInterface::AliasResult::NoAlias)
1181+
continue;
1182+
nameBuffer = std::move(aliasBuffer);
1183+
assert(!nameBuffer.empty() && "expected valid alias name");
1184+
if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
1185+
break;
1186+
}
11711187
}
11721188

11731189
if (nameBuffer.empty())

mlir/test/IR/op-asm-interface.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,14 @@ func.func @block_argument_name_from_op_asm_type_interface() {
2222
}
2323
return
2424
}
25+
26+
// -----
27+
28+
// CHECK: !op_asm_type_interface_type =
29+
!type = !test.op_asm_type_interface
30+
31+
func.func @alias_from_op_asm_type_interface() {
32+
// CHECK-LABEL: @alias_from_op_asm_type_interface
33+
%0 = "test.result_name_from_type"() : () -> !type
34+
return
35+
}

mlir/test/lib/Dialect/Test/TestTypeDefs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
399399
}
400400

401401
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
402-
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
402+
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
403403
let mnemonic = "op_asm_type_interface";
404404
}
405405

mlir/test/lib/Dialect/Test/TestTypes.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,3 +537,9 @@ void TestTypeOpAsmTypeInterfaceType::getAsmName(
537537
OpAsmSetNameFn setNameFn) const {
538538
setNameFn("op_asm_type_interface");
539539
}
540+
541+
::mlir::OpAsmDialectInterface::AliasResult
542+
TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const {
543+
os << "op_asm_type_interface_type";
544+
return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
545+
}

0 commit comments

Comments
 (0)