Skip to content

[mlir] NamedAttribute utility generator #75118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
let hasOperationAttrVerify = 1;

let extraClassDeclaration = [{
/// Get the name of the attribute used to annotate external kernel
/// functions.
static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
return ::llvm::StringLiteral("rocdl.flat_work_group_size");
}
static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
}

/// The address space value that represents global memory.
static constexpr unsigned kGlobalMemoryAddressSpace = 1;
/// The address space value that represents shared memory.
Expand All @@ -58,6 +48,22 @@ class ROCDL_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
let mnemonic = attrMnemonic;
}

//===----------------------------------------------------------------------===//
// ROCDL named attr definitions
//===----------------------------------------------------------------------===//

class ROCDL_NamedAttr<string name, string userName, string baseAttrType = "::mlir::Attribute"> :
NamedAttrDef<ROCDL_Dialect, name, userName, baseAttrType>;

def ROCDL_KernelAttr : ROCDL_NamedAttr<"Kernel", "kernel", "::mlir::UnitAttr">;
def ROCDL_ReqdWorkGroupSizeAttr :
ROCDL_NamedAttr<"ReqdWorkGroupSize", "reqd_work_group_size", "::mlir::DenseI32ArrayAttr">;
def ROCDL_FlatWorkGroupSizeAttr :
ROCDL_NamedAttr<"FlatWorkGroupSize", "flat_work_group_size", "::mlir::StringAttr">;
def ROCDL_MaxFlatWorkGroupSizeAttr :
ROCDL_NamedAttr<"MaxFlatWorkGroupSize", "max_flat_work_group_size", "::mlir::IntegerAttr">;
def ROCDL_WavesPerEuAttr :
ROCDL_NamedAttr<"WavesPerEu", "waves_per_eu", "::mlir::IntegerAttr">;

//===----------------------------------------------------------------------===//
// ROCDL op definitions
Expand Down
69 changes: 69 additions & 0 deletions mlir/include/mlir/IR/AttrTypeBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,75 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
"::" # cppClassName # ">($_self)">;
}

// Define a StringAttr wrapper for the NamedAttribute `name`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "NamedAttribute" part isn't clear to me.

Also this is more than a "wrapper", it seems to me this is actually registering its own full fledge attribute?
I'm not sure I follow the definition, but I suspect this won't be uniqued with StringAttr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is no different than a StringAttr, so the usual interface works, but the new get only takes MLIRContext*. And once it is created, there is no way to get back to the wrapper class but it is mostly static lookup methods anyway.

// - `name` is dialect-qualified, but mnemonic is based
// - Utilities to is/has/get/set/lookup/create typed Attr on an Operation
// including typed `value` attribute
class NamedAttrDef<Dialect dialect, string name, string userName,
string valueAttrType = "::mlir::Attribute">
: AttrDef<dialect, name, [], "::mlir::StringAttr"> {
let mnemonic = userName;

string scopedName = dialect.name # "." # mnemonic;
code getNameFunc = "static constexpr llvm::StringLiteral getName() { return \""
# scopedName # "\"; }\n";
code typedefValueAttr = "typedef " # valueAttrType # " ValueAttrType;\n";

code namedAttrDecls = !strconcat(typedefValueAttr, getNameFunc, [{
// Is or Has
static bool is(::mlir::NamedAttribute &attr) {
return attr.getName() == getName() && ::llvm::isa<ValueAttrType>(attr.getValue());
}
static bool isInherent(::mlir::NamedAttribute &attr) {
return attr.getName() == getMnemonic();
}
static bool has(::mlir::Operation *op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The grammar of this method name feels weird. Would it make more sense as something like .isPresent or .isSetOn or some such? Asking if an attribute has an operation feels backwards

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, to follow the rest of the interface it could be hasValue.

return op->hasAttrOfType<ValueAttrType>(getName());
}
// Get Name
static ::mlir::StringAttr get(::mlir::MLIRContext *ctx) {
return ::mlir::StringAttr::get(ctx, getName());
}
// Get Value
static ValueAttrType getValue(::mlir::Operation *op) {
return op->getAttrOfType<ValueAttrType>(getName());
}
// Scoped lookup for inheritance
static ValueAttrType lookupValue(::mlir::Operation *op) {
if (auto attr = getValue(op))
return attr;
std::optional<::mlir::RegisteredOperationName> opInfo = op->getRegisteredInfo();
if (!opInfo || !opInfo->hasTrait<::mlir::OpTrait::IsIsolatedFromAbove>()) {
if (auto *par = op->getParentOp())
return lookupValue(par);
}
return ValueAttrType();
}
// Set Value on Op
static void setValue(::mlir::Operation *op, ValueAttrType val) {
assert(op);
op->setAttr(getName(), val);
}
// Remove Value from Op
static void removeValue(::mlir::Operation *op) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The terminology is confusing: why are you using "Value" for referring to "Attributes"?

assert(op);
op->removeAttr(getName());
}
// Create (scoped) NamedAttribute
static ::mlir::NamedAttribute create(::mlir::Builder &b, ValueAttrType val);
}]);

code namedAttrDefs = [{
// Create (scoped) NamedAttribute
::mlir::NamedAttribute $cppClass::create(::mlir::Builder &b, $cppClass::ValueAttrType val) {
return b.getNamedAttr($cppClass::getName(), val);
}
}];

let extraClassDeclaration = namedAttrDecls;
let extraClassDefinition = namedAttrDefs;
}

// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
Expand Down
9 changes: 3 additions & 6 deletions mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ struct LowerGpuOpsToROCDLOpsPass
m.walk([ctx](LLVM::LLVMFuncOp op) {
if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
blockSizes);
ROCDL::ReqdWorkGroupSizeAttr::setValue(op, blockSizes);
// Also set up the rocdl.flat_work_group_size attribute to prevent
// conflicting metadata.
uint32_t flatSize = 1;
Expand All @@ -301,8 +300,7 @@ struct LowerGpuOpsToROCDLOpsPass
}
StringAttr flatSizeAttr =
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
flatSizeAttr);
ROCDL::FlatWorkGroupSizeAttr::setValue(op, flatSizeAttr);
}
});
}
Expand Down Expand Up @@ -355,8 +353,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
converter,
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
StringAttr::get(&converter.getContext(),
ROCDL::ROCDLDialect::getKernelFuncAttrName()));
ROCDL::KernelAttr::get(&converter.getContext()));
if (Runtime::HIP == runtime) {
patterns.add<GPUPrintfOpToHIPLowering>(converter);
} else if (Runtime::OpenCL == runtime) {
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// Kernel function attribute should be attached to functions.
if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
if (ROCDL::KernelAttr::is(attr)) {
if (!isa<LLVM::LLVMFuncOp>(op)) {
return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
return op->emitError() << "'" << ROCDL::KernelAttr::getName()
<< "' attribute attached to unexpected op";
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ROCDLDialectLLVMIRTranslationInterface
LogicalResult
amendOperation(Operation *op, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
if (ROCDL::KernelAttr::is(attribute)) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
Expand All @@ -105,7 +105,7 @@ class ROCDLDialectLLVMIRTranslationInterface
// Override flat-work-group-size
// TODO: update clients to rocdl.flat_work_group_size instead,
// then remove this half of the branch
if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
if (ROCDL::MaxFlatWorkGroupSizeAttr::is(attribute)) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
Expand All @@ -120,8 +120,7 @@ class ROCDLDialectLLVMIRTranslationInterface
attrValueStream << "1," << value.getInt();
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
}
if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
attribute.getName()) {
if (ROCDL::FlatWorkGroupSizeAttr::is(attribute)) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
Expand All @@ -137,8 +136,7 @@ class ROCDLDialectLLVMIRTranslationInterface
}

// Set reqd_work_group_size metadata
if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
attribute.getName()) {
if (ROCDL::ReqdWorkGroupSizeAttr::is(attribute)) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
Expand Down
64 changes: 64 additions & 0 deletions mlir/test/IR/test-named-attrs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// RUN: mlir-opt %s -test-named-attrs -split-input-file --verify-diagnostics | FileCheck %s

func.func @f_unit_attr() attributes {test.test_unit} { // expected-remark {{found unit attr}}
%0:2 = "test.producer"() : () -> (i32, i32)
return
}

// -----

func.func @f_unit_attr_fail() attributes {test.test_unit_fail} { // expected-error {{missing unit attr}}
%0:2 = "test.producer"() : () -> (i32, i32)
return
}

// -----

func.func @f_int_attr() attributes {test.test_int = 42 : i32} { // expected-remark {{correct int value}}
%0:2 = "test.producer"() : () -> (i32, i32)
return
}

// -----

func.func @f_int_attr_fail() attributes {test.test_int = 24 : i32} { // expected-error {{wrong int value}}
%0:2 = "test.producer"() : () -> (i32, i32)
return
}

// -----

func.func @f_int_attr_fail2() attributes {test.test_int_fail = 42 : i64} { // expected-error {{missing int attr}}
%0:2 = "test.producer"() : () -> (i32, i32)
return
}

// -----

func.func @f_lookup_attr() attributes {test.test_int = 42 : i64} { // expected-remark {{lookup found attr}}
%0:2 = "test.producer"() : () -> (i32, i32) // expected-remark {{lookup found attr}}
return // expected-remark {{lookup found attr}}
}

// -----

func.func @f_lookup_attr2() { // expected-error {{lookup failed}}
"test.any_attr_of_i32_str"() {attr = 3 : i32, test.test_int = 24 : i32} : () -> () // expected-remark {{lookup found attr}}
return // expected-error {{lookup failed}}
}

// -----

func.func @f_lookup_attr_fail() attributes {test.test_int_fail = 42 : i64} { // expected-error {{lookup failed}}
%0:2 = "test.producer"() : () -> (i32, i32) // expected-error {{lookup failed}}
return // expected-error {{lookup failed}}
}

// -----

// CHECK: func.func @f_set_attr() attributes {test.test_int = 42 : i32}
func.func @f_set_attr() { // expected-remark {{set int attr}}
%0:2 = "test.producer"() : () -> (i32, i32)
return
}

7 changes: 7 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,13 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> {
}


// Test NamedAttr attributes
class Test_NamedAttr<string name, string userName, string baseAttrType = "::mlir::Attribute"> :
NamedAttrDef<Test_Dialect, name, userName, baseAttrType>;

def Test_NamedUnitAttr : Test_NamedAttr<"TestNamedUnit", "test_unit", "::mlir::UnitAttr">;
def Test_NamedIntAttr : Test_NamedAttr<"TestNamedInt", "test_int", "::mlir::IntegerAttr">;



#endif // TEST_ATTRDEFS
2 changes: 2 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class CopyCount {
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const test::CopyCount &value);

#include "mlir/IR/Operation.h"

/// A handle used to reference external elements instances.
using TestDialectResourceBlobHandle =
mlir::DialectResourceBlobHandle<TestDialect>;
Expand Down
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/PatternBase.td"
Expand Down
1 change: 1 addition & 0 deletions mlir/test/lib/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_mlir_library(MLIRTestIR
TestInterfaces.cpp
TestMatchers.cpp
TestLazyLoading.cpp
TestNamedAttrs.cpp
TestOpaqueLoc.cpp
TestOperationEquals.cpp
TestPrintDefUse.cpp
Expand Down
94 changes: 94 additions & 0 deletions mlir/test/lib/IR/TestNamedAttrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
//===- TestNamedAttrs.cpp - Test passes for MLIR types
//-------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestAttributes.h"
#include "TestDialect.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace test;

namespace {
struct TestNamedAttrsPass
: public PassWrapper<TestNamedAttrsPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNamedAttrsPass)

StringRef getArgument() const final { return "test-named-attrs"; }
StringRef getDescription() const final {
return "Test support for recursive types";
}
void runOnOperation() override {
func::FuncOp func = getOperation();

auto funcName = func.getName();
// Just make sure recursive types are printed and parsed.
if (funcName.contains("f_unit_attr")) {
if (test::TestNamedUnitAttr::has(func)) {
func.emitRemark() << "found unit attr";
} else {
func.emitOpError() << "missing unit attr";
signalPassFailure();
}
return;
}

if (funcName.contains("f_int_attr")) {
if (test::TestNamedIntAttr::has(func)) {
if (test::TestNamedIntAttr::getValue(func).getInt() == 42) {
func.emitRemark() << "correct int value";
} else {
func.emitOpError() << "wrong int value";
signalPassFailure();
}
return;
} else {
func.emitOpError() << "missing int attr";
signalPassFailure();
}
return;
}

if (funcName.contains("f_lookup_attr")) {
func.walk([&](Operation *op) {
if (test::TestNamedIntAttr::lookupValue(op)) {
op->emitRemark() << "lookup found attr";
} else {
op->emitOpError() << "lookup failed";
signalPassFailure();
}
});
return;
}

if (funcName.contains("f_set_attr")) {
if (!test::TestNamedIntAttr::has(func)) {
auto intTy = IntegerType::get(func.getContext(), 32);
test::TestNamedIntAttr::setValue(func, IntegerAttr::get(intTy, 42));
func.emitRemark() << "set int attr";
} else {
func.emitOpError() << "attr already set";
signalPassFailure();
}
return;
}

// Unknown key.
func.emitOpError() << "unexpected function name";
signalPassFailure();
}
};
} // namespace

namespace mlir {
namespace test {

void registerTestNamedAttrsPass() { PassRegistration<TestNamedAttrsPass>(); }

} // namespace test
} // namespace mlir
Loading