Skip to content

Commit 9e55292

Browse files
author
SJW
committed
[mlir] NamedAttribute utility generator
All attributes in MLIR are named, inherent attributes have unscoped names and discardable attributes should be scoped with a dialect. Current usage is ad-hoc and much of the codebase is sprinkled with constant strings used to lookup and set attributes, leading to potential bugs when names are not updated in all usages. This PR adds a tablegen'd utility wrapper for a NamedAttribute that manages scoped/unscoped name lookup for consistent typed access the attribute on an Operation.
1 parent 15617d1 commit 9e55292

File tree

12 files changed

+265
-24
lines changed

12 files changed

+265
-24
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
2828
let hasOperationAttrVerify = 1;
2929

3030
let extraClassDeclaration = [{
31-
/// Get the name of the attribute used to annotate external kernel
32-
/// functions.
33-
static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
34-
static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
35-
return ::llvm::StringLiteral("rocdl.flat_work_group_size");
36-
}
37-
static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
38-
return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
39-
}
40-
4131
/// The address space value that represents global memory.
4232
static constexpr unsigned kGlobalMemoryAddressSpace = 1;
4333
/// The address space value that represents shared memory.
@@ -58,6 +48,22 @@ class ROCDL_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
5848
let mnemonic = attrMnemonic;
5949
}
6050

51+
//===----------------------------------------------------------------------===//
52+
// ROCDL named attr definitions
53+
//===----------------------------------------------------------------------===//
54+
55+
class ROCDL_NamedAttr<string name, string userName, string baseAttrType = "::mlir::Attribute"> :
56+
NamedAttrDef<ROCDL_Dialect, name, userName, baseAttrType>;
57+
58+
def ROCDL_KernelAttr : ROCDL_NamedAttr<"Kernel", "kernel", "::mlir::UnitAttr">;
59+
def ROCDL_ReqdWorkGroupSizeAttr :
60+
ROCDL_NamedAttr<"ReqdWorkGroupSize", "reqd_work_group_size", "::mlir::DenseI32ArrayAttr">;
61+
def ROCDL_FlatWorkGroupSizeAttr :
62+
ROCDL_NamedAttr<"FlatWorkGroupSize", "flat_work_group_size", "::mlir::StringAttr">;
63+
def ROCDL_MaxFlatWorkGroupSizeAttr :
64+
ROCDL_NamedAttr<"MaxFlatWorkGroupSize", "max_flat_work_group_size", "::mlir::IntegerAttr">;
65+
def ROCDL_WavesPerEuAttr :
66+
ROCDL_NamedAttr<"WavesPerEu", "waves_per_eu", "::mlir::IntegerAttr">;
6167

6268
//===----------------------------------------------------------------------===//
6369
// ROCDL op definitions

mlir/include/mlir/IR/AttrTypeBase.td

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,75 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
283283
"::" # cppClassName # ">($_self)">;
284284
}
285285

286+
// Define a StringAttr wrapper for the NamedAttribute `name`
287+
// - `name` is dialect-qualified, but mnemonic is based
288+
// - Utilities to is/has/get/set/lookup/create typed Attr on an Operation
289+
// including typed `value` attribute
290+
class NamedAttrDef<Dialect dialect, string name, string userName,
291+
string valueAttrType = "::mlir::Attribute">
292+
: AttrDef<dialect, name, [], "::mlir::StringAttr"> {
293+
let mnemonic = userName;
294+
295+
string scopedName = dialect.name # "." # mnemonic;
296+
code getNameFunc = "static constexpr llvm::StringLiteral getName() { return \""
297+
# scopedName # "\"; }\n";
298+
code typedefValueAttr = "typedef " # valueAttrType # " ValueAttrType;\n";
299+
300+
code namedAttrDecls = !strconcat(typedefValueAttr, getNameFunc, [{
301+
// Is or Has
302+
static bool is(::mlir::NamedAttribute &attr) {
303+
return attr.getName() == getName() && ::llvm::isa<ValueAttrType>(attr.getValue());
304+
}
305+
static bool isInherent(::mlir::NamedAttribute &attr) {
306+
return attr.getName() == getMnemonic();
307+
}
308+
static bool has(::mlir::Operation *op) {
309+
return op->hasAttrOfType<ValueAttrType>(getName());
310+
}
311+
// Get Name
312+
static ::mlir::StringAttr get(::mlir::MLIRContext *ctx) {
313+
return ::mlir::StringAttr::get(ctx, getName());
314+
}
315+
// Get Value
316+
static ValueAttrType getValue(::mlir::Operation *op) {
317+
return op->getAttrOfType<ValueAttrType>(getName());
318+
}
319+
// Scoped lookup for inheritance
320+
static ValueAttrType lookupValue(::mlir::Operation *op) {
321+
if (auto attr = getValue(op))
322+
return attr;
323+
std::optional<::mlir::RegisteredOperationName> opInfo = op->getRegisteredInfo();
324+
if (!opInfo || !opInfo->hasTrait<::mlir::OpTrait::IsIsolatedFromAbove>()) {
325+
if (auto *par = op->getParentOp())
326+
return lookupValue(par);
327+
}
328+
return ValueAttrType();
329+
}
330+
// Set Value on Op
331+
static void setValue(::mlir::Operation *op, ValueAttrType val) {
332+
assert(op);
333+
op->setAttr(getName(), val);
334+
}
335+
// Remove Value from Op
336+
static void removeValue(::mlir::Operation *op) {
337+
assert(op);
338+
op->removeAttr(getName());
339+
}
340+
// Create (scoped) NamedAttribute
341+
static ::mlir::NamedAttribute create(::mlir::Builder &b, ValueAttrType val);
342+
}]);
343+
344+
code namedAttrDefs = [{
345+
// Create (scoped) NamedAttribute
346+
::mlir::NamedAttribute $cppClass::create(::mlir::Builder &b, $cppClass::ValueAttrType val) {
347+
return b.getNamedAttr($cppClass::getName(), val);
348+
}
349+
}];
350+
351+
let extraClassDeclaration = namedAttrDecls;
352+
let extraClassDefinition = namedAttrDefs;
353+
}
354+
286355
// Define a new type, named `name`, belonging to `dialect` that inherits from
287356
// the given C++ base class.
288357
class TypeDef<Dialect dialect, string name, list<Trait> traits = [],

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,7 @@ struct LowerGpuOpsToROCDLOpsPass
291291
m.walk([ctx](LLVM::LLVMFuncOp op) {
292292
if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
293293
op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
294-
op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
295-
blockSizes);
294+
ROCDL::ReqdWorkGroupSizeAttr::setValue(op, blockSizes);
296295
// Also set up the rocdl.flat_work_group_size attribute to prevent
297296
// conflicting metadata.
298297
uint32_t flatSize = 1;
@@ -301,8 +300,7 @@ struct LowerGpuOpsToROCDLOpsPass
301300
}
302301
StringAttr flatSizeAttr =
303302
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
304-
op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
305-
flatSizeAttr);
303+
ROCDL::FlatWorkGroupSizeAttr::setValue(op, flatSizeAttr);
306304
}
307305
});
308306
}
@@ -355,8 +353,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
355353
converter,
356354
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
357355
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
358-
StringAttr::get(&converter.getContext(),
359-
ROCDL::ROCDLDialect::getKernelFuncAttrName()));
356+
ROCDL::KernelAttr::get(&converter.getContext()));
360357
if (Runtime::HIP == runtime) {
361358
patterns.add<GPUPrintfOpToHIPLowering>(converter);
362359
} else if (Runtime::OpenCL == runtime) {

mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
253253
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
254254
NamedAttribute attr) {
255255
// Kernel function attribute should be attached to functions.
256-
if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
256+
if (ROCDL::KernelAttr::is(attr)) {
257257
if (!isa<LLVM::LLVMFuncOp>(op)) {
258-
return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
258+
return op->emitError() << "'" << ROCDL::KernelAttr::getName()
259259
<< "' attribute attached to unexpected op";
260260
}
261261
}

mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class ROCDLDialectLLVMIRTranslationInterface
8383
LogicalResult
8484
amendOperation(Operation *op, NamedAttribute attribute,
8585
LLVM::ModuleTranslation &moduleTranslation) const final {
86-
if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
86+
if (ROCDL::KernelAttr::is(attribute)) {
8787
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
8888
if (!func)
8989
return failure();
@@ -105,7 +105,7 @@ class ROCDLDialectLLVMIRTranslationInterface
105105
// Override flat-work-group-size
106106
// TODO: update clients to rocdl.flat_work_group_size instead,
107107
// then remove this half of the branch
108-
if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
108+
if (ROCDL::MaxFlatWorkGroupSizeAttr::is(attribute)) {
109109
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
110110
if (!func)
111111
return failure();
@@ -120,8 +120,7 @@ class ROCDLDialectLLVMIRTranslationInterface
120120
attrValueStream << "1," << value.getInt();
121121
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
122122
}
123-
if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
124-
attribute.getName()) {
123+
if (ROCDL::FlatWorkGroupSizeAttr::is(attribute)) {
125124
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
126125
if (!func)
127126
return failure();
@@ -137,8 +136,7 @@ class ROCDLDialectLLVMIRTranslationInterface
137136
}
138137

139138
// Set reqd_work_group_size metadata
140-
if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
141-
attribute.getName()) {
139+
if (ROCDL::ReqdWorkGroupSizeAttr::is(attribute)) {
142140
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
143141
if (!func)
144142
return failure();

mlir/test/IR/test-named-attrs.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// RUN: mlir-opt %s -test-named-attrs -split-input-file --verify-diagnostics | FileCheck %s
2+
3+
func.func @f_unit_attr() attributes {test.test_unit} { // expected-remark {{found unit attr}}
4+
%0:2 = "test.producer"() : () -> (i32, i32)
5+
return
6+
}
7+
8+
// -----
9+
10+
func.func @f_unit_attr_fail() attributes {test.test_unit_fail} { // expected-error {{missing unit attr}}
11+
%0:2 = "test.producer"() : () -> (i32, i32)
12+
return
13+
}
14+
15+
// -----
16+
17+
func.func @f_int_attr() attributes {test.test_int = 42 : i32} { // expected-remark {{correct int value}}
18+
%0:2 = "test.producer"() : () -> (i32, i32)
19+
return
20+
}
21+
22+
// -----
23+
24+
func.func @f_int_attr_fail() attributes {test.test_int = 24 : i32} { // expected-error {{wrong int value}}
25+
%0:2 = "test.producer"() : () -> (i32, i32)
26+
return
27+
}
28+
29+
// -----
30+
31+
func.func @f_int_attr_fail2() attributes {test.test_int_fail = 42 : i64} { // expected-error {{missing int attr}}
32+
%0:2 = "test.producer"() : () -> (i32, i32)
33+
return
34+
}
35+
36+
// -----
37+
38+
func.func @f_lookup_attr() attributes {test.test_int = 42 : i64} { // expected-remark {{lookup found attr}}
39+
%0:2 = "test.producer"() : () -> (i32, i32) // expected-remark {{lookup found attr}}
40+
return // expected-remark {{lookup found attr}}
41+
}
42+
43+
// -----
44+
45+
func.func @f_lookup_attr2() { // expected-error {{lookup failed}}
46+
"test.any_attr_of_i32_str"() {attr = 3 : i32, test.test_int = 24 : i32} : () -> () // expected-remark {{lookup found attr}}
47+
return // expected-error {{lookup failed}}
48+
}
49+
50+
// -----
51+
52+
func.func @f_lookup_attr_fail() attributes {test.test_int_fail = 42 : i64} { // expected-error {{lookup failed}}
53+
%0:2 = "test.producer"() : () -> (i32, i32) // expected-error {{lookup failed}}
54+
return // expected-error {{lookup failed}}
55+
}
56+
57+
// -----
58+
59+
// CHECK: func.func @f_set_attr() attributes {test.test_int = 42 : i32}
60+
func.func @f_set_attr() { // expected-remark {{set int attr}}
61+
%0:2 = "test.producer"() : () -> (i32, i32)
62+
return
63+
}
64+

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,13 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> {
333333
}
334334

335335

336+
// Test NamedAttr attributes
337+
class Test_NamedAttr<string name, string userName, string baseAttrType = "::mlir::Attribute"> :
338+
NamedAttrDef<Test_Dialect, name, userName, baseAttrType>;
339+
340+
def Test_NamedUnitAttr : Test_NamedAttr<"TestNamedUnit", "test_unit", "::mlir::UnitAttr">;
341+
def Test_NamedIntAttr : Test_NamedAttr<"TestNamedInt", "test_int", "::mlir::IntegerAttr">;
342+
336343

337344

338345
#endif // TEST_ATTRDEFS

mlir/test/lib/Dialect/Test/TestAttributes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class CopyCount {
4343
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
4444
const test::CopyCount &value);
4545

46+
#include "mlir/IR/Operation.h"
47+
4648
/// A handle used to reference external elements instances.
4749
using TestDialectResourceBlobHandle =
4850
mlir::DialectResourceBlobHandle<TestDialect>;

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ include "mlir/Dialect/DLTI/DLTIBase.td"
1515
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
1616
include "mlir/IR/EnumAttr.td"
1717
include "mlir/Interfaces/FunctionInterfaces.td"
18+
include "mlir/IR/AttrTypeBase.td"
1819
include "mlir/IR/OpBase.td"
1920
include "mlir/IR/OpAsmInterface.td"
2021
include "mlir/IR/PatternBase.td"

mlir/test/lib/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTestIR
1010
TestInterfaces.cpp
1111
TestMatchers.cpp
1212
TestLazyLoading.cpp
13+
TestNamedAttrs.cpp
1314
TestOpaqueLoc.cpp
1415
TestOperationEquals.cpp
1516
TestPrintDefUse.cpp

mlir/test/lib/IR/TestNamedAttrs.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
//===- TestNamedAttrs.cpp - Test passes for MLIR types
2+
//-------------------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "TestAttributes.h"
11+
#include "TestDialect.h"
12+
#include "mlir/Pass/Pass.h"
13+
14+
using namespace mlir;
15+
using namespace test;
16+
17+
namespace {
18+
struct TestNamedAttrsPass
19+
: public PassWrapper<TestNamedAttrsPass, OperationPass<func::FuncOp>> {
20+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNamedAttrsPass)
21+
22+
StringRef getArgument() const final { return "test-named-attrs"; }
23+
StringRef getDescription() const final {
24+
return "Test support for recursive types";
25+
}
26+
void runOnOperation() override {
27+
func::FuncOp func = getOperation();
28+
29+
auto funcName = func.getName();
30+
// Just make sure recursive types are printed and parsed.
31+
if (funcName.contains("f_unit_attr")) {
32+
if (test::TestNamedUnitAttr::has(func)) {
33+
func.emitRemark() << "found unit attr";
34+
} else {
35+
func.emitOpError() << "missing unit attr";
36+
signalPassFailure();
37+
}
38+
return;
39+
}
40+
41+
if (funcName.contains("f_int_attr")) {
42+
if (test::TestNamedIntAttr::has(func)) {
43+
if (test::TestNamedIntAttr::getValue(func).getInt() == 42) {
44+
func.emitRemark() << "correct int value";
45+
} else {
46+
func.emitOpError() << "wrong int value";
47+
signalPassFailure();
48+
}
49+
return;
50+
} else {
51+
func.emitOpError() << "missing int attr";
52+
signalPassFailure();
53+
}
54+
return;
55+
}
56+
57+
if (funcName.contains("f_lookup_attr")) {
58+
func.walk([&](Operation *op) {
59+
if (test::TestNamedIntAttr::lookupValue(op)) {
60+
op->emitRemark() << "lookup found attr";
61+
} else {
62+
op->emitOpError() << "lookup failed";
63+
signalPassFailure();
64+
}
65+
});
66+
return;
67+
}
68+
69+
if (funcName.contains("f_set_attr")) {
70+
if (!test::TestNamedIntAttr::has(func)) {
71+
auto intTy = IntegerType::get(func.getContext(), 32);
72+
test::TestNamedIntAttr::setValue(func, IntegerAttr::get(intTy, 42));
73+
func.emitRemark() << "set int attr";
74+
} else {
75+
func.emitOpError() << "attr already set";
76+
signalPassFailure();
77+
}
78+
return;
79+
}
80+
81+
// Unknown key.
82+
func.emitOpError() << "unexpected function name";
83+
signalPassFailure();
84+
}
85+
};
86+
} // namespace
87+
88+
namespace mlir {
89+
namespace test {
90+
91+
void registerTestNamedAttrsPass() { PassRegistration<TestNamedAttrsPass>(); }
92+
93+
} // namespace test
94+
} // namespace mlir

0 commit comments

Comments
 (0)