Skip to content

[MLIR][LLVM] Attach kernel metadata representation to llvm.func #101314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,23 @@ def LLVM_UndefAttr : LLVM_Attr<"Undef", "undef">;
/// Folded into from LLVM::PoisonOp.
def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;

//===----------------------------------------------------------------------===//
// VecTypeHintAttr
//===----------------------------------------------------------------------===//

def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
let summary = "Explicit vectorization compiler hint";
let description = [{
A hint to the compiler that indicates most operations used in the function
are explictly vectorized using a particular vector type. `$hint` is the
vector or scalar type in particular. `$is_signed` can be used with integer
types to state whether the type is signed.
}];
let parameters = (ins "TypeAttr":$hint,
DefaultValuedParameter<"bool", "false">:$is_signed);
let assemblyFormat = "`<` struct(params) `>`";
}

//===----------------------------------------------------------------------===//
// ZeroAttr
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1456,7 +1456,11 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<UnitAttr>:$always_inline,
OptionalAttr<UnitAttr>:$no_unwind,
OptionalAttr<UnitAttr>:$will_return,
OptionalAttr<UnitAttr>:$optimize_none
OptionalAttr<UnitAttr>:$optimize_none,
OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,
OptionalAttr<I32Attr>:$intel_reqd_sub_group_size
);

let regions = (region AnyRegion:$body);
Expand Down
12 changes: 8 additions & 4 deletions mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ class LLVMImportDialectInterface

/// Hook for derived dialect interfaces to publish the supported metadata
/// kinds. As every metadata kind has a unique integer identifier, the
/// function returns the list of supported metadata identifiers.
virtual ArrayRef<unsigned> getSupportedMetadata() const { return {}; }
/// function returns the list of supported metadata identifiers. `ctx` can be
/// used to obtain IDs of metadata kinds that do not have a fixed static one.
virtual ArrayRef<unsigned>
getSupportedMetadata(llvm::LLVMContext &ctx) const {
return {};
}
};

/// Interface collection for the import of LLVM IR that dispatches to a concrete
Expand All @@ -101,7 +105,7 @@ class LLVMImportInterface
/// intrinsic and metadata kinds and builds the dispatch tables for the
/// conversion. Returns failure if multiple dialect interfaces translate the
/// same LLVM IR intrinsic.
LogicalResult initializeImport() {
LogicalResult initializeImport(llvm::LLVMContext &llvmContext) {
for (const LLVMImportDialectInterface &iface : *this) {
// Verify the supported intrinsics have not been mapped before.
const auto *intrinsicIt =
Expand Down Expand Up @@ -139,7 +143,7 @@ class LLVMImportInterface
for (unsigned id : iface.getSupportedInstructions())
instructionToDialect[id] = &iface;
// Add a mapping for all supported metadata kinds.
for (unsigned kind : iface.getSupportedMetadata())
for (unsigned kind : iface.getSupportedMetadata(llvmContext))
metadataToDialect[kind].push_back(iface.getDialect());
}

Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ class ModuleImport {
/// dialect interfaces for the supported LLVM IR intrinsics and metadata kinds
/// and builds the dispatch tables. Returns failure if multiple dialect
/// interfaces translate the same LLVM IR intrinsic.
LogicalResult initializeImportInterface() { return iface.initializeImport(); }
LogicalResult initializeImportInterface() {
return iface.initializeImport(llvmModule->getContext());
}

/// Converts all functions of the LLVM module to MLIR functions.
LogicalResult convertFunctions();
Expand Down
158 changes: 152 additions & 6 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ using namespace mlir::LLVM::detail;

#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"

static constexpr StringLiteral vecTypeHintMDName = "vec_type_hint";
static constexpr StringLiteral workGroupSizeHintMDName = "work_group_size_hint";
static constexpr StringLiteral reqdWorkGroupSizeMDName = "reqd_work_group_size";
static constexpr StringLiteral intelReqdSubGroupSizeMDName =
"intel_reqd_sub_group_size";

/// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
/// intrinsic. Returns false otherwise.
static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
Expand Down Expand Up @@ -70,11 +76,18 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,

/// Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM
/// dialect attributes.
static ArrayRef<unsigned> getSupportedMetadataImpl() {
static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
static const SmallVector<unsigned> convertibleMetadata = {
llvm::LLVMContext::MD_prof, llvm::LLVMContext::MD_tbaa,
llvm::LLVMContext::MD_access_group, llvm::LLVMContext::MD_loop,
llvm::LLVMContext::MD_noalias, llvm::LLVMContext::MD_alias_scope};
llvm::LLVMContext::MD_prof,
llvm::LLVMContext::MD_tbaa,
llvm::LLVMContext::MD_access_group,
llvm::LLVMContext::MD_loop,
llvm::LLVMContext::MD_noalias,
llvm::LLVMContext::MD_alias_scope,
context.getMDKindID(vecTypeHintMDName),
context.getMDKindID(workGroupSizeHintMDName),
context.getMDKindID(reqdWorkGroupSizeMDName),
context.getMDKindID(intelReqdSubGroupSizeMDName)};
return convertibleMetadata;
}

Expand Down Expand Up @@ -226,6 +239,128 @@ static LogicalResult setNoaliasScopesAttr(const llvm::MDNode *node,
return success();
}

/// Extracts an integer from the provided metadata `md` if possible. Returns
/// nullopt otherwise.
static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
auto *constant = dyn_cast_if_present<llvm::ConstantAsMetadata>(md);
if (!constant)
return {};

auto *intConstant = dyn_cast<llvm::ConstantInt>(constant->getValue());
if (!intConstant)
return {};

return intConstant->getValue().getSExtValue();
}

/// Converts the provided metadata node `node` to an LLVM dialect
/// VecTypeHintAttr if possible.
static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *node,
ModuleImport &moduleImport) {
if (!node || node->getNumOperands() != 2)
return {};

auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(node->getOperand(0).get());
if (!hintMD)
return {};
TypeAttr hint = TypeAttr::get(moduleImport.convertType(hintMD->getType()));

std::optional<int32_t> optIsSigned =
parseIntegerMD(node->getOperand(1).get());
if (!optIsSigned)
return {};
bool isSigned = *optIsSigned != 0;

return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
}

/// Converts the provided metadata node `node` to an MLIR DenseI32ArrayAttr if
/// possible.
static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
llvm::MDNode *node) {
if (!node)
return {};
SmallVector<int32_t> vals;
for (const llvm::MDOperand &op : node->operands()) {
std::optional<int32_t> mdValue = parseIntegerMD(op.get());
if (!mdValue)
return {};
vals.push_back(*mdValue);
}
return builder.getDenseI32ArrayAttr(vals);
}

/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *node) {
if (!node || node->getNumOperands() != 1)
return {};
std::optional<int32_t> val = parseIntegerMD(node->getOperand(0));
if (!val)
return {};
return builder.getI32IntegerAttr(*val);
}

static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
Operation *op,
LLVM::ModuleImport &moduleImport) {
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!funcOp)
return failure();

VecTypeHintAttr attr = convertVecTypeHint(builder, node, moduleImport);
if (!attr)
return failure();

funcOp.setVecTypeHintAttr(attr);
return success();
}

static LogicalResult
setWorkGroupSizeHintAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!funcOp)
return failure();

DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
if (!attr)
return failure();

funcOp.setWorkGroupSizeHintAttr(attr);
return success();
}

static LogicalResult
setReqdWorkGroupSizeAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!funcOp)
return failure();

DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
if (!attr)
return failure();

funcOp.setReqdWorkGroupSizeAttr(attr);
return success();
}

/// Converts the given intel required subgroup size metadata node to an MLIR
/// attribute and attaches it to the imported operation if the translation
/// succeeds. Returns failure otherwise.
static LogicalResult setIntelReqdSubGroupSizeAttr(Builder &builder,
llvm::MDNode *node,
Operation *op) {
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!funcOp)
return failure();

IntegerAttr attr = convertIntegerMD(builder, node);
if (!attr)
return failure();

funcOp.setIntelReqdSubGroupSizeAttr(attr);
return success();
}

namespace {

/// Implementation of the dialect interface that converts operations belonging
Expand Down Expand Up @@ -261,6 +396,16 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
if (kind == llvm::LLVMContext::MD_noalias)
return setNoaliasScopesAttr(node, op, moduleImport);

llvm::LLVMContext &context = node->getContext();
if (kind == context.getMDKindID(vecTypeHintMDName))
return setVecTypeHintAttr(builder, node, op, moduleImport);
if (kind == context.getMDKindID(workGroupSizeHintMDName))
return setWorkGroupSizeHintAttr(builder, node, op);
if (kind == context.getMDKindID(reqdWorkGroupSizeMDName))
return setReqdWorkGroupSizeAttr(builder, node, op);
if (kind == context.getMDKindID(intelReqdSubGroupSizeMDName))
return setIntelReqdSubGroupSizeAttr(builder, node, op);

// A handler for a supported metadata kind is missing.
llvm_unreachable("unknown metadata type");
}
Expand All @@ -273,8 +418,9 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {

/// Returns the list of LLVM IR metadata kinds that are convertible to MLIR
/// LLVM dialect attributes.
ArrayRef<unsigned> getSupportedMetadata() const final {
return getSupportedMetadataImpl();
ArrayRef<unsigned>
getSupportedMetadata(llvm::LLVMContext &context) const final {
return getSupportedMetadataImpl(context);
}
};
} // namespace
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,10 +500,10 @@ LogicalResult ModuleImport::convertLinkerOptionsMetadata() {
if (named.getName() != "llvm.linker.options")
continue;
// llvm.linker.options operands are lists of strings.
for (const llvm::MDNode *md : named.operands()) {
for (const llvm::MDNode *node : named.operands()) {
SmallVector<StringRef> options;
options.reserve(md->getNumOperands());
for (const llvm::MDOperand &option : md->operands())
options.reserve(node->getNumOperands());
for (const llvm::MDOperand &option : node->operands())
options.push_back(cast<llvm::MDString>(option)->getString());
builder.create<LLVM::LinkerOptionsOp>(mlirModule.getLoc(),
builder.getStrArrayAttr(options));
Expand Down
76 changes: 76 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,41 @@ static LogicalResult checkedAddLLVMFnAttribute(Location loc,
return success();
}

/// Return a representation of `value` as metadata.
static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
const llvm::APInt &value) {
llvm::Constant *constant = llvm::ConstantInt::get(context, value);
return llvm::ConstantAsMetadata::get(constant);
}

/// Return a representation of `value` as an MDNode.
static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
const llvm::APInt &value) {
return llvm::MDNode::get(context, convertIntegerToMetadata(context, value));
}

/// Return an MDNode encoding `vec_type_hint` metadata.
static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
llvm::Type *type,
bool isSigned) {
llvm::Metadata *typeMD =
llvm::ConstantAsMetadata::get(llvm::UndefValue::get(type));
llvm::Metadata *isSignedMD =
convertIntegerToMetadata(context, llvm::APInt(32, isSigned ? 1 : 0));
return llvm::MDNode::get(context, {typeMD, isSignedMD});
}

/// Return an MDNode with a tuple given by the values in `values`.
static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
ArrayRef<int32_t> values) {
SmallVector<llvm::Metadata *> mdValues;
llvm::transform(
values, std::back_inserter(mdValues), [&context](int32_t value) {
return convertIntegerToMetadata(context, llvm::APInt(32, value));
});
return llvm::MDNode::get(context, mdValues);
}

/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
/// Reports error to `loc` if any and returns immediately. Expects `attributes`
/// to be an array attribute containing either string attributes, treated as
Expand Down Expand Up @@ -1448,6 +1483,44 @@ static void convertFunctionAttributes(LLVMFuncOp func,
convertFunctionMemoryAttributes(func, llvmFunc);
}

/// Converts function attributes from `func` and attaches them to `llvmFunc`.
static void convertFunctionKernelAttributes(LLVMFuncOp func,
llvm::Function *llvmFunc,
ModuleTranslation &translation) {
llvm::LLVMContext &llvmContext = llvmFunc->getContext();

if (VecTypeHintAttr vecTypeHint = func.getVecTypeHintAttr()) {
Type type = vecTypeHint.getHint().getValue();
llvm::Type *llvmType = translation.convertType(type);
bool isSigned = vecTypeHint.getIsSigned();
llvmFunc->setMetadata(
func.getVecTypeHintAttrName(),
convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
}

if (std::optional<ArrayRef<int32_t>> workGroupSizeHint =
func.getWorkGroupSizeHint()) {
llvmFunc->setMetadata(
func.getWorkGroupSizeHintAttrName(),
convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
}

if (std::optional<ArrayRef<int32_t>> reqdWorkGroupSize =
func.getReqdWorkGroupSize()) {
llvmFunc->setMetadata(
func.getReqdWorkGroupSizeAttrName(),
convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
}

if (std::optional<uint32_t> intelReqdSubGroupSize =
func.getIntelReqdSubGroupSize()) {
llvmFunc->setMetadata(
func.getIntelReqdSubGroupSizeAttrName(),
convertIntegerToMDNode(llvmContext,
llvm::APInt(32, *intelReqdSubGroupSize)));
}
}

FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
DictionaryAttr paramAttrs) {
Expand Down Expand Up @@ -1492,6 +1565,9 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert function attributes.
convertFunctionAttributes(function, llvmFunc);

// Convert function kernel attributes to metadata.
convertFunctionKernelAttributes(function, llvmFunc, *this);

// Convert function_entry_count attribute to metadata.
if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
llvmFunc->setEntryCount(entryCount.value());
Expand Down
Loading
Loading