-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] NamedAttribute utility generator #75118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -283,6 +283,75 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [], | |
"::" # cppClassName # ">($_self)">; | ||
} | ||
|
||
// Define a StringAttr wrapper for the NamedAttribute `name` | ||
// - `name` is dialect-qualified, but mnemonic is based | ||
// - Utilities to is/has/get/set/lookup/create typed Attr on an Operation | ||
// including typed `value` attribute | ||
class NamedAttrDef<Dialect dialect, string name, string userName, | ||
string valueAttrType = "::mlir::Attribute"> | ||
: AttrDef<dialect, name, [], "::mlir::StringAttr"> { | ||
let mnemonic = userName; | ||
|
||
string scopedName = dialect.name # "." # mnemonic; | ||
code getNameFunc = "static constexpr llvm::StringLiteral getName() { return \"" | ||
# scopedName # "\"; }\n"; | ||
code typedefValueAttr = "typedef " # valueAttrType # " ValueAttrType;\n"; | ||
|
||
code namedAttrDecls = !strconcat(typedefValueAttr, getNameFunc, [{ | ||
// Is or Has | ||
static bool is(::mlir::NamedAttribute &attr) { | ||
return attr.getName() == getName() && ::llvm::isa<ValueAttrType>(attr.getValue()); | ||
} | ||
static bool isInherent(::mlir::NamedAttribute &attr) { | ||
return attr.getName() == getMnemonic(); | ||
} | ||
static bool has(::mlir::Operation *op) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The grammar of this method name feels weird. Would it make more sense as something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, to follow the rest of the interface it could be |
||
return op->hasAttrOfType<ValueAttrType>(getName()); | ||
} | ||
// Get Name | ||
static ::mlir::StringAttr get(::mlir::MLIRContext *ctx) { | ||
return ::mlir::StringAttr::get(ctx, getName()); | ||
} | ||
// Get Value | ||
static ValueAttrType getValue(::mlir::Operation *op) { | ||
return op->getAttrOfType<ValueAttrType>(getName()); | ||
} | ||
// Scoped lookup for inheritance | ||
static ValueAttrType lookupValue(::mlir::Operation *op) { | ||
if (auto attr = getValue(op)) | ||
return attr; | ||
std::optional<::mlir::RegisteredOperationName> opInfo = op->getRegisteredInfo(); | ||
if (!opInfo || !opInfo->hasTrait<::mlir::OpTrait::IsIsolatedFromAbove>()) { | ||
if (auto *par = op->getParentOp()) | ||
return lookupValue(par); | ||
} | ||
return ValueAttrType(); | ||
} | ||
// Set Value on Op | ||
static void setValue(::mlir::Operation *op, ValueAttrType val) { | ||
assert(op); | ||
op->setAttr(getName(), val); | ||
} | ||
// Remove Value from Op | ||
static void removeValue(::mlir::Operation *op) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The terminology is confusing: why are you using "Value" for referring to "Attributes"? |
||
assert(op); | ||
op->removeAttr(getName()); | ||
} | ||
// Create (scoped) NamedAttribute | ||
static ::mlir::NamedAttribute create(::mlir::Builder &b, ValueAttrType val); | ||
}]); | ||
|
||
code namedAttrDefs = [{ | ||
// Create (scoped) NamedAttribute | ||
::mlir::NamedAttribute $cppClass::create(::mlir::Builder &b, $cppClass::ValueAttrType val) { | ||
return b.getNamedAttr($cppClass::getName(), val); | ||
} | ||
}]; | ||
|
||
let extraClassDeclaration = namedAttrDecls; | ||
let extraClassDefinition = namedAttrDefs; | ||
} | ||
|
||
// Define a new type, named `name`, belonging to `dialect` that inherits from | ||
// the given C++ base class. | ||
class TypeDef<Dialect dialect, string name, list<Trait> traits = [], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// RUN: mlir-opt %s -test-named-attrs -split-input-file --verify-diagnostics | FileCheck %s | ||
|
||
func.func @f_unit_attr() attributes {test.test_unit} { // expected-remark {{found unit attr}} | ||
%0:2 = "test.producer"() : () -> (i32, i32) | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
func.func @f_unit_attr_fail() attributes {test.test_unit_fail} { // expected-error {{missing unit attr}} | ||
%0:2 = "test.producer"() : () -> (i32, i32) | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
func.func @f_int_attr() attributes {test.test_int = 42 : i32} { // expected-remark {{correct int value}} | ||
%0:2 = "test.producer"() : () -> (i32, i32) | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
func.func @f_int_attr_fail() attributes {test.test_int = 24 : i32} { // expected-error {{wrong int value}} | ||
%0:2 = "test.producer"() : () -> (i32, i32) | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
func.func @f_int_attr_fail2() attributes {test.test_int_fail = 42 : i64} { // expected-error {{missing int attr}} | ||
%0:2 = "test.producer"() : () -> (i32, i32) | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
func.func @f_lookup_attr() attributes {test.test_int = 42 : i64} { // expected-remark {{lookup found attr}} | ||
%0:2 = "test.producer"() : () -> (i32, i32) // expected-remark {{lookup found attr}} | ||
return // expected-remark {{lookup found attr}} | ||
} | ||
|
||
// ----- | ||
|
||
func.func @f_lookup_attr2() { // expected-error {{lookup failed}} | ||
"test.any_attr_of_i32_str"() {attr = 3 : i32, test.test_int = 24 : i32} : () -> () // expected-remark {{lookup found attr}} | ||
return // expected-error {{lookup failed}} | ||
} | ||
|
||
// ----- | ||
|
||
func.func @f_lookup_attr_fail() attributes {test.test_int_fail = 42 : i64} { // expected-error {{lookup failed}} | ||
%0:2 = "test.producer"() : () -> (i32, i32) // expected-error {{lookup failed}} | ||
return // expected-error {{lookup failed}} | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK: func.func @f_set_attr() attributes {test.test_int = 42 : i32} | ||
func.func @f_set_attr() { // expected-remark {{set int attr}} | ||
%0:2 = "test.producer"() : () -> (i32, i32) | ||
return | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
//===- TestNamedAttrs.cpp - Test passes for MLIR types | ||
//-------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "TestAttributes.h" | ||
#include "TestDialect.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
using namespace mlir; | ||
using namespace test; | ||
|
||
namespace { | ||
struct TestNamedAttrsPass | ||
: public PassWrapper<TestNamedAttrsPass, OperationPass<func::FuncOp>> { | ||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNamedAttrsPass) | ||
|
||
StringRef getArgument() const final { return "test-named-attrs"; } | ||
StringRef getDescription() const final { | ||
return "Test support for recursive types"; | ||
} | ||
void runOnOperation() override { | ||
func::FuncOp func = getOperation(); | ||
|
||
auto funcName = func.getName(); | ||
// Just make sure recursive types are printed and parsed. | ||
if (funcName.contains("f_unit_attr")) { | ||
if (test::TestNamedUnitAttr::has(func)) { | ||
func.emitRemark() << "found unit attr"; | ||
} else { | ||
func.emitOpError() << "missing unit attr"; | ||
signalPassFailure(); | ||
} | ||
return; | ||
} | ||
|
||
if (funcName.contains("f_int_attr")) { | ||
if (test::TestNamedIntAttr::has(func)) { | ||
if (test::TestNamedIntAttr::getValue(func).getInt() == 42) { | ||
func.emitRemark() << "correct int value"; | ||
} else { | ||
func.emitOpError() << "wrong int value"; | ||
signalPassFailure(); | ||
} | ||
return; | ||
} else { | ||
func.emitOpError() << "missing int attr"; | ||
signalPassFailure(); | ||
} | ||
return; | ||
} | ||
|
||
if (funcName.contains("f_lookup_attr")) { | ||
func.walk([&](Operation *op) { | ||
if (test::TestNamedIntAttr::lookupValue(op)) { | ||
op->emitRemark() << "lookup found attr"; | ||
} else { | ||
op->emitOpError() << "lookup failed"; | ||
signalPassFailure(); | ||
} | ||
}); | ||
return; | ||
} | ||
|
||
if (funcName.contains("f_set_attr")) { | ||
if (!test::TestNamedIntAttr::has(func)) { | ||
auto intTy = IntegerType::get(func.getContext(), 32); | ||
test::TestNamedIntAttr::setValue(func, IntegerAttr::get(intTy, 42)); | ||
func.emitRemark() << "set int attr"; | ||
} else { | ||
func.emitOpError() << "attr already set"; | ||
signalPassFailure(); | ||
} | ||
return; | ||
} | ||
|
||
// Unknown key. | ||
func.emitOpError() << "unexpected function name"; | ||
signalPassFailure(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
namespace mlir { | ||
namespace test { | ||
|
||
void registerTestNamedAttrsPass() { PassRegistration<TestNamedAttrsPass>(); } | ||
|
||
} // namespace test | ||
} // namespace mlir |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The "NamedAttribute" part isn't clear to me.
Also this is more than a "wrapper", it seems to me this is actually registering its own full fledge attribute?
I'm not sure I follow the definition, but I suspect this won't be uniqued with StringAttr.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is no different than a StringAttr, so the usual interface works, but the new
get
only takes MLIRContext*. And once it is created, there is no way to get back to the wrapper class but it is mostly static lookup methods anyway.