Skip to content

[MLIR] Add ODS support for generating helpers for dialect (discardable) attributes #77024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
let hasOperationAttrVerify = 1;

let extraClassDeclaration = [{
/// Get the name of the attribute used to annotate external kernel
/// functions.
static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
return ::llvm::StringLiteral("rocdl.flat_work_group_size");
}
static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
}

/// The address space value that represents global memory.
static constexpr unsigned kGlobalMemoryAddressSpace = 1;
/// The address space value that represents shared memory.
Expand All @@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
static constexpr unsigned kPrivateMemoryAddressSpace = 5;
}];

let discardableAttrs = (ins
"::mlir::UnitAttr":$kernel,
"::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
"::mlir::StringAttr":$flat_work_group_size,
"::mlir::IntegerAttr":$max_flat_work_group_size,
"::mlir::IntegerAttr":$waves_per_eu
);

let useDefaultAttributePrinterParser = 1;
}

Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/DialectBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class Dialect {
// pattern or interfaces.
list<string> dependentDialects = [];

// A list of key/value pair representing an attribute type and a name.
// This will generate helper classes on the dialect to be able to
// manage discardable attributes on operations in a type safe way.
dag discardableAttrs = (ins);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is my mental model correct that these are all effectively optional attributes for which there are helpers generated? But they are different than the others as not verified, most likely empty? (e.g., I'm not sure if doing this vs making these optional attributes when one would select which ...)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if I followed correctly what you're asking about, but these are helpers to manage discardable attributes. The idea is that for example the ROCDL dialect involves a discardable attribute with the key rocdl.reqd_work_group_size, and the associated Attribute must be a DenseI32ArrayAttr.
This feature will generate a helper for managing this discardable attribute on operations. The operations don't know anything about the attribute, it's a property of the dialect itself.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to me. These being dialect known discardable rather than op known optional.


// The C++ namespace that ops of this dialect should be placed into.
//
// By default, uses the name of the dialect as the only namespace. To avoid
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/TableGen/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#define MLIR_TABLEGEN_DIALECT_H_

#include "mlir/Support/LLVM.h"
#include "llvm/TableGen/Record.h"

#include <string>
#include <vector>

Expand Down Expand Up @@ -90,6 +92,10 @@ class Dialect {
/// dialect.
bool usePropertiesForAttributes() const;

llvm::DagInit *getDiscardableAttributes() const;

const llvm::Record *getDef() const { return def; }

// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;
Expand Down
17 changes: 9 additions & 8 deletions mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();

auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
auto reqdWorkGroupSizeAttrHelper =
rocdlDialect->getReqdWorkGroupSizeAttrHelper();
auto flatWorkGroupSizeAttrHelper =
rocdlDialect->getFlatWorkGroupSizeAttrHelper();
// Manually rewrite known block size attributes so the LLVMIR translation
// infrastructure can pick them up.
m.walk([ctx](LLVM::LLVMFuncOp op) {
m.walk([&](LLVM::LLVMFuncOp op) {
if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
blockSizes);
reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
// Also set up the rocdl.flat_work_group_size attribute to prevent
// conflicting metadata.
uint32_t flatSize = 1;
Expand All @@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
}
StringAttr flatSizeAttr =
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
flatSizeAttr);
flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
}
});
}
Expand Down Expand Up @@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
converter,
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
StringAttr::get(&converter.getContext(),
ROCDL::ROCDLDialect::getKernelFuncAttrName()));
ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
if (Runtime::HIP == runtime) {
patterns.add<GPUPrintfOpToHIPLowering>(converter);
} else if (Runtime::OpenCL == runtime) {
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// Kernel function attribute should be attached to functions.
if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
if (kernelAttrName.getName() == attr.getName()) {
if (!isa<LLVM::LLVMFuncOp>(op)) {
return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
return op->emitError() << "'" << kernelAttrName.getName()
<< "' attribute attached to unexpected op";
}
}
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/TableGen/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
return def->getValueAsBit("usePropertiesForAttributes");
}

llvm::DagInit *Dialect::getDiscardableAttributes() const {
return def->getValueAsDag("discardableAttrs");
}

bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
Expand All @@ -99,12 +100,12 @@ class ROCDLDialectLLVMIRTranslationInterface
if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) {
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1,256");
}

}
// Override flat-work-group-size
// TODO: update clients to rocdl.flat_work_group_size instead,
// then remove this half of the branch
if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
Expand All @@ -119,7 +120,7 @@ class ROCDLDialectLLVMIRTranslationInterface
attrValueStream << "1," << value.getInt();
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
}
if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
Expand All @@ -136,7 +137,7 @@ class ROCDLDialectLLVMIRTranslationInterface
}

// Set reqd_work_group_size metadata
if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
let discardableAttrs = (ins
"mlir::IntegerAttr":$discardable_attr_key,
"SimpleAAttr":$other_discardable_attr_key
);

let extraClassDeclaration = [{
void registerAttributes();
Expand Down
96 changes: 92 additions & 4 deletions mlir/tools/mlir-tblgen/DialectGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ using DialectFilterIterator =
std::function<bool(const llvm::Record *)>>;
} // namespace

static void populateDiscardableAttributes(
Dialect &dialect, llvm::DagInit *discardableAttrDag,
SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
llvm::Init *arg = discardableAttrDag->getArg(i);

StringRef givenName = discardableAttrDag->getArgNameStr(i);
if (givenName.empty())
PrintFatalError(dialect.getDef()->getLoc(),
"discardable attributes must be named");
discardableAttributes.push_back(
{givenName.str(), arg->getAsUnquotedString()});
}
}

/// Given a set of records for a T, filter the ones that correspond to
/// the given dialect.
template <typename T>
Expand Down Expand Up @@ -180,6 +195,44 @@ static const char *const operationInterfaceFallbackDecl = R"(
mlir::OperationName opName) override;
)";

/// The code block for the discardable attribute helper.
static const char *const discardableAttrHelperDecl = R"(
/// Helper to manage the discardable attribute `{1}`.
class {0}AttrHelper {{
::mlir::StringAttr name;
public:
static constexpr ::llvm::StringLiteral getNameStr() {{
return "{4}.{1}";
}
constexpr ::mlir::StringAttr getName() {{
return name;
}

{0}AttrHelper(::mlir::MLIRContext *ctx)
: name(::mlir::StringAttr::get(ctx, getNameStr())) {{}

{2} getAttr(::mlir::Operation *op) {{
return op->getAttrOfType<{2}>(name);
}
void setAttr(::mlir::Operation *op, {2} val) {{
op->setAttr(name, val);
}
bool isAttrPresent(::mlir::Operation *op) {{
return op->hasAttrOfType<{2}>(name);
}
void removeAttr(::mlir::Operation *op) {{
assert(op->hasAttrOfType<{2}>(name));
op->removeAttr(name);
}
};
{0}AttrHelper get{0}AttrHelper() {
return {3}AttrName;
}
private:
{0}AttrHelper {3}AttrName;
public:
)";

/// Generate the declaration for the given dialect class.
static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
// Emit all nested namespaces.
Expand Down Expand Up @@ -215,6 +268,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
os << regionResultAttrVerifierDecl;
if (dialect.hasOperationInterfaceFallback())
os << operationInterfaceFallbackDecl;

llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
SmallVector<std::pair<std::string, std::string>> discardableAttributes;
populateDiscardableAttributes(dialect, discardableAttrDag,
discardableAttributes);

for (const auto &attrPair : discardableAttributes) {
std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
attrPair.first, /*capitalizeFirst=*/true);
std::string camelName = llvm::convertToCamelFromSnakeCase(
attrPair.first, /*capitalizeFirst=*/false);
os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
attrPair.first, attrPair.second, camelName,
dialect.getName());
}

if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
os << *extraDecl;

Expand Down Expand Up @@ -252,9 +321,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
/// {1}: Initialization code that is emitted in the ctor body before calling
/// initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
/// {3}: Extra members to initialize
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
: ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
: ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
{3}
{{
{1}
initialize();
}
Expand All @@ -268,7 +340,9 @@ static const char *const dialectDestructorStr = R"(

)";

static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
static void emitDialectDef(Dialect &dialect,
const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();

// Emit the TypeID explicit specializations to have a single symbol def.
Expand All @@ -295,8 +369,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
// Emit the constructor and destructor.
StringRef superClassName =
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";

llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
SmallVector<std::pair<std::string, std::string>> discardableAttributes;
populateDiscardableAttributes(dialect, discardableAttrDag,
discardableAttributes);
std::string discardableAttributesInit;
for (const auto &attrPair : discardableAttributes) {
std::string camelName = llvm::convertToCamelFromSnakeCase(
attrPair.first, /*capitalizeFirst=*/false);
llvm::raw_string_ostream os(discardableAttributesInit);
os << ", " << camelName << "AttrName(context)";
}

os << llvm::formatv(dialectConstructorStr, cppClassName,
dependentDialectRegistrations, superClassName);
dependentDialectRegistrations, superClassName,
discardableAttributesInit);
if (!dialect.hasNonDefaultDestructor())
os << llvm::formatv(dialectDestructorStr, cppClassName);
}
Expand All @@ -313,7 +401,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
std::optional<Dialect> dialect = findDialectToGenerate(dialects);
if (!dialect)
return true;
emitDialectDef(*dialect, os);
emitDialectDef(*dialect, recordKeeper, os);
return false;
}

Expand Down