Skip to content

Commit f64d744

Browse files
committed
[MLIR] Add ODS support for generating helpers for dialect (discardable) attributes
WIP (missing docs)
1 parent 79e6231 commit f64d744

File tree

9 files changed

+126
-28
lines changed

9 files changed

+126
-28
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
2828
let hasOperationAttrVerify = 1;
2929

3030
let extraClassDeclaration = [{
31-
/// Get the name of the attribute used to annotate external kernel
32-
/// functions.
33-
static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
34-
static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
35-
return ::llvm::StringLiteral("rocdl.flat_work_group_size");
36-
}
37-
static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
38-
return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
39-
}
40-
4131
/// The address space value that represents global memory.
4232
static constexpr unsigned kGlobalMemoryAddressSpace = 1;
4333
/// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
4636
static constexpr unsigned kPrivateMemoryAddressSpace = 5;
4737
}];
4838

39+
let discardableAttrs = (ins
40+
"::mlir::UnitAttr":$kernel,
41+
"::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
42+
"::mlir::StringAttr":$flat_work_group_size,
43+
"::mlir::IntegerAttr":$max_flat_work_group_size,
44+
"::mlir::IntegerAttr":$waves_per_eu
45+
);
46+
4947
let useDefaultAttributePrinterParser = 1;
5048
}
5149

mlir/include/mlir/IR/DialectBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class Dialect {
3434
// pattern or interfaces.
3535
list<string> dependentDialects = [];
3636

37+
dag discardableAttrs = (ins);
38+
3739
// The C++ namespace that ops of this dialect should be placed into.
3840
//
3941
// By default, uses the name of the dialect as the only namespace. To avoid

mlir/include/mlir/TableGen/Dialect.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#define MLIR_TABLEGEN_DIALECT_H_
1515

1616
#include "mlir/Support/LLVM.h"
17+
#include "llvm/TableGen/Record.h"
18+
1719
#include <string>
1820
#include <vector>
1921

@@ -90,6 +92,10 @@ class Dialect {
9092
/// dialect.
9193
bool usePropertiesForAttributes() const;
9294

95+
llvm::DagInit *getDiscardableAttributes() const;
96+
97+
const llvm::Record *getDef() const { return def; }
98+
9399
// Returns whether two dialects are equal by checking the equality of the
94100
// underlying record.
95101
bool operator==(const Dialect &other) const;

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
285285
configureGpuToROCDLConversionLegality(target);
286286
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
287287
signalPassFailure();
288-
288+
auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
289+
auto reqdWorkGroupSizeAttrHelper =
290+
rocdlDialect->getReqdWorkGroupSizeAttrHelper();
291+
auto flatWorkGroupSizeAttrHelper =
292+
rocdlDialect->getFlatWorkGroupSizeAttrHelper();
289293
// Manually rewrite known block size attributes so the LLVMIR translation
290294
// infrastructure can pick them up.
291-
m.walk([ctx](LLVM::LLVMFuncOp op) {
295+
m.walk([&](LLVM::LLVMFuncOp op) {
292296
if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
293297
op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
294-
op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
295-
blockSizes);
298+
reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
296299
// Also set up the rocdl.flat_work_group_size attribute to prevent
297300
// conflicting metadata.
298301
uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
301304
}
302305
StringAttr flatSizeAttr =
303306
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
304-
op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
305-
flatSizeAttr);
307+
flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
306308
}
307309
});
308310
}
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
355357
converter,
356358
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
357359
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
358-
StringAttr::get(&converter.getContext(),
359-
ROCDL::ROCDLDialect::getKernelFuncAttrName()));
360+
ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
360361
if (Runtime::HIP == runtime) {
361362
patterns.add<GPUPrintfOpToHIPLowering>(converter);
362363
} else if (Runtime::OpenCL == runtime) {

mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
253253
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
254254
NamedAttribute attr) {
255255
// Kernel function attribute should be attached to functions.
256-
if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
256+
if (kernelAttrName.getName() == attr.getName()) {
257257
if (!isa<LLVM::LLVMFuncOp>(op)) {
258-
return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
258+
return op->emitError() << "'" << kernelAttrName.getName()
259259
<< "' attribute attached to unexpected op";
260260
}
261261
}

mlir/lib/TableGen/Dialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
106106
return def->getValueAsBit("usePropertiesForAttributes");
107107
}
108108

109+
llvm::DagInit *Dialect::getDiscardableAttributes() const {
110+
return def->getValueAsDag("discardableAttrs");
111+
}
112+
109113
bool Dialect::operator==(const Dialect &other) const {
110114
return def == other.def;
111115
}

mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
8484
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
8585
NamedAttribute attribute,
8686
LLVM::ModuleTranslation &moduleTranslation) const final {
87-
if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
87+
auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
88+
if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
8889
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
8990
if (!func)
9091
return failure();
@@ -106,7 +107,8 @@ class ROCDLDialectLLVMIRTranslationInterface
106107
// Override flat-work-group-size
107108
// TODO: update clients to rocdl.flat_work_group_size instead,
108109
// then remove this half of the branch
109-
if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
110+
if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
111+
attribute.getName()) {
110112
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
111113
if (!func)
112114
return failure();
@@ -121,7 +123,7 @@ class ROCDLDialectLLVMIRTranslationInterface
121123
attrValueStream << "1," << value.getInt();
122124
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
123125
}
124-
if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
126+
if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
125127
attribute.getName()) {
126128
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
127129
if (!func)
@@ -138,7 +140,7 @@ class ROCDLDialectLLVMIRTranslationInterface
138140
}
139141

140142
// Set reqd_work_group_size metadata
141-
if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
143+
if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
142144
attribute.getName()) {
143145
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
144146
if (!func)

mlir/test/lib/Dialect/Test/TestDialect.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
2525
let useDefaultAttributePrinterParser = 1;
2626
let isExtensible = 1;
2727
let dependentDialects = ["::mlir::DLTIDialect"];
28+
let discardableAttrs = (ins
29+
"mlir::IntegerAttr":$discardable_attr_key,
30+
"SimpleAAttr":$other_discardable_attr_key
31+
);
2832

2933
let extraClassDeclaration = [{
3034
void registerAttributes();

mlir/tools/mlir-tblgen/DialectGen.cpp

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,21 @@ using DialectFilterIterator =
4343
std::function<bool(const llvm::Record *)>>;
4444
} // namespace
4545

46+
static void populateDiscardableAttributes(
47+
Dialect &dialect, llvm::DagInit *discardableAttrDag,
48+
SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
49+
for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
50+
llvm::Init *arg = discardableAttrDag->getArg(i);
51+
52+
StringRef givenName = discardableAttrDag->getArgNameStr(i);
53+
if (givenName.empty())
54+
PrintFatalError(dialect.getDef()->getLoc(),
55+
"discardable attributes must be named");
56+
discardableAttributes.push_back(
57+
{givenName.str(), arg->getAsUnquotedString()});
58+
}
59+
}
60+
4661
/// Given a set of records for a T, filter the ones that correspond to
4762
/// the given dialect.
4863
template <typename T>
@@ -181,6 +196,37 @@ static const char *const operationInterfaceFallbackDecl = R"(
181196
mlir::OperationName opName) override;
182197
)";
183198

199+
/// The code block for the discardable attribute helper.
200+
static const char *const discardableAttrHelperDecl = R"(
201+
/// Helper to manage the discardable attribute `{1}`.
202+
class {0}AttrHelper {{
203+
mlir::StringAttr name;
204+
public:
205+
static constexpr llvm::StringLiteral getNameStr() {{
206+
return "{4}.{1}";
207+
}
208+
constexpr mlir::StringAttr getName() {{
209+
return name;
210+
}
211+
212+
{0}AttrHelper(mlir::MLIRContext *ctx)
213+
: name(mlir::StringAttr::get(ctx, getNameStr())) {{}
214+
215+
{2} getAttr(::mlir::Operation *op) {{
216+
return op->getAttrOfType<{2}>(getName());
217+
}
218+
void setAttr(::mlir::Operation *op, {2} val) {{
219+
op->setAttr(getName(), val);
220+
}
221+
};
222+
{0}AttrHelper get{0}AttrHelper() {
223+
return {3}AttrName;
224+
}
225+
private:
226+
{0}AttrHelper {3}AttrName;
227+
public:
228+
)";
229+
184230
/// Generate the declaration for the given dialect class.
185231
static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
186232
// Emit all nested namespaces.
@@ -216,6 +262,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
216262
os << regionResultAttrVerifierDecl;
217263
if (dialect.hasOperationInterfaceFallback())
218264
os << operationInterfaceFallbackDecl;
265+
266+
llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
267+
SmallVector<std::pair<std::string, std::string>> discardableAttributes;
268+
populateDiscardableAttributes(dialect, discardableAttrDag,
269+
discardableAttributes);
270+
271+
for (const auto &attrPair : discardableAttributes) {
272+
std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
273+
attrPair.first, /*capitalizeFirst=*/true);
274+
std::string camelName = llvm::convertToCamelFromSnakeCase(
275+
attrPair.first, /*capitalizeFirst=*/false);
276+
os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
277+
attrPair.first, attrPair.second, camelName,
278+
dialect.getName());
279+
}
280+
219281
if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
220282
os << *extraDecl;
221283

@@ -253,9 +315,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
253315
/// {1}: initialization code that is emitted in the ctor body before calling
254316
/// initialize().
255317
/// {2}: The dialect parent class.
318+
/// {3}: Extra members to initialize
256319
static const char *const dialectConstructorStr = R"(
257320
{0}::{0}(::mlir::MLIRContext *context)
258-
: ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
321+
: ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
322+
{3}
323+
{{
259324
{1}
260325
initialize();
261326
}
@@ -269,7 +334,9 @@ static const char *const dialectDestructorStr = R"(
269334
270335
)";
271336

272-
static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
337+
static void emitDialectDef(Dialect &dialect,
338+
const llvm::RecordKeeper &recordKeeper,
339+
raw_ostream &os) {
273340
std::string cppClassName = dialect.getCppClassName();
274341

275342
// Emit the TypeID explicit specializations to have a single symbol def.
@@ -292,8 +359,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
292359
// Emit the constructor and destructor.
293360
StringRef superClassName =
294361
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
362+
363+
llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
364+
SmallVector<std::pair<std::string, std::string>> discardableAttributes;
365+
populateDiscardableAttributes(dialect, discardableAttrDag,
366+
discardableAttributes);
367+
std::string discardableAttributesInit;
368+
for (const auto &attrPair : discardableAttributes) {
369+
std::string camelName = llvm::convertToCamelFromSnakeCase(
370+
attrPair.first, /*capitalizeFirst=*/false);
371+
llvm::raw_string_ostream os(discardableAttributesInit);
372+
os << ", " << camelName << "AttrName(context)";
373+
}
374+
295375
os << llvm::formatv(dialectConstructorStr, cppClassName,
296-
dependentDialectRegistrations, superClassName);
376+
dependentDialectRegistrations, superClassName,
377+
discardableAttributesInit);
297378
if (!dialect.hasNonDefaultDestructor())
298379
os << llvm::formatv(dialectDestructorStr, cppClassName);
299380
}
@@ -310,7 +391,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
310391
std::optional<Dialect> dialect = findDialectToGenerate(dialects);
311392
if (!dialect)
312393
return true;
313-
emitDialectDef(*dialect, os);
394+
emitDialectDef(*dialect, recordKeeper, os);
314395
return false;
315396
}
316397

0 commit comments

Comments
 (0)