Skip to content

Commit 6d2bbba

Browse files
authored
[MLIR][LLVM] Attach kernel metadata representation to llvm.func (#101314)
Add optional attributes to `llvm.func` representing LLVM so-called "kernel" metadata: - [`vec_type_hint`](https://clang.llvm.org/docs/AttributeReference.html#vec-type-hint) - [`work_group_size_hint`](https://clang.llvm.org/docs/AttributeReference.html#work-group-size-hint) - [`reqd_work_group_size`](https://clang.llvm.org/docs/AttributeReference.html#reqd-work-group-size) - [`intel_reqd_sub_group_size`](https://clang.llvm.org/docs/AttributeReference.html#intel-reqd-sub-group-size). --------- Signed-off-by: Victor Perez <[email protected]>
1 parent 92e18ff commit 6d2bbba

File tree

10 files changed

+383
-16
lines changed

10 files changed

+383
-16
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,23 @@ def LLVM_UndefAttr : LLVM_Attr<"Undef", "undef">;
10711071
/// Folded into from LLVM::PoisonOp.
10721072
def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
10731073

1074+
//===----------------------------------------------------------------------===//
1075+
// VecTypeHintAttr
1076+
//===----------------------------------------------------------------------===//
1077+
1078+
def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
1079+
let summary = "Explicit vectorization compiler hint";
1080+
let description = [{
1081+
A hint to the compiler that indicates most operations used in the function
1082+
are explictly vectorized using a particular vector type. `$hint` is the
1083+
vector or scalar type in particular. `$is_signed` can be used with integer
1084+
types to state whether the type is signed.
1085+
}];
1086+
let parameters = (ins "TypeAttr":$hint,
1087+
DefaultValuedParameter<"bool", "false">:$is_signed);
1088+
let assemblyFormat = "`<` struct(params) `>`";
1089+
}
1090+
10741091
//===----------------------------------------------------------------------===//
10751092
// ZeroAttr
10761093
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1456,7 +1456,11 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
14561456
OptionalAttr<UnitAttr>:$always_inline,
14571457
OptionalAttr<UnitAttr>:$no_unwind,
14581458
OptionalAttr<UnitAttr>:$will_return,
1459-
OptionalAttr<UnitAttr>:$optimize_none
1459+
OptionalAttr<UnitAttr>:$optimize_none,
1460+
OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
1461+
OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
1462+
OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,
1463+
OptionalAttr<I32Attr>:$intel_reqd_sub_group_size
14601464
);
14611465

14621466
let regions = (region AnyRegion:$body);

mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,12 @@ class LLVMImportDialectInterface
8484

8585
/// Hook for derived dialect interfaces to publish the supported metadata
8686
/// kinds. As every metadata kind has a unique integer identifier, the
87-
/// function returns the list of supported metadata identifiers.
88-
virtual ArrayRef<unsigned> getSupportedMetadata() const { return {}; }
87+
/// function returns the list of supported metadata identifiers. `ctx` can be
88+
/// used to obtain IDs of metadata kinds that do not have a fixed static one.
89+
virtual ArrayRef<unsigned>
90+
getSupportedMetadata(llvm::LLVMContext &ctx) const {
91+
return {};
92+
}
8993
};
9094

9195
/// Interface collection for the import of LLVM IR that dispatches to a concrete
@@ -101,7 +105,7 @@ class LLVMImportInterface
101105
/// intrinsic and metadata kinds and builds the dispatch tables for the
102106
/// conversion. Returns failure if multiple dialect interfaces translate the
103107
/// same LLVM IR intrinsic.
104-
LogicalResult initializeImport() {
108+
LogicalResult initializeImport(llvm::LLVMContext &llvmContext) {
105109
for (const LLVMImportDialectInterface &iface : *this) {
106110
// Verify the supported intrinsics have not been mapped before.
107111
const auto *intrinsicIt =
@@ -139,7 +143,7 @@ class LLVMImportInterface
139143
for (unsigned id : iface.getSupportedInstructions())
140144
instructionToDialect[id] = &iface;
141145
// Add a mapping for all supported metadata kinds.
142-
for (unsigned kind : iface.getSupportedMetadata())
146+
for (unsigned kind : iface.getSupportedMetadata(llvmContext))
143147
metadataToDialect[kind].push_back(iface.getDialect());
144148
}
145149

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ class ModuleImport {
5353
/// dialect interfaces for the supported LLVM IR intrinsics and metadata kinds
5454
/// and builds the dispatch tables. Returns failure if multiple dialect
5555
/// interfaces translate the same LLVM IR intrinsic.
56-
LogicalResult initializeImportInterface() { return iface.initializeImport(); }
56+
LogicalResult initializeImportInterface() {
57+
return iface.initializeImport(llvmModule->getContext());
58+
}
5759

5860
/// Converts all functions of the LLVM module to MLIR functions.
5961
LogicalResult convertFunctions();

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Lines changed: 152 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ using namespace mlir::LLVM::detail;
3232

3333
#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
3434

35+
static constexpr StringLiteral vecTypeHintMDName = "vec_type_hint";
36+
static constexpr StringLiteral workGroupSizeHintMDName = "work_group_size_hint";
37+
static constexpr StringLiteral reqdWorkGroupSizeMDName = "reqd_work_group_size";
38+
static constexpr StringLiteral intelReqdSubGroupSizeMDName =
39+
"intel_reqd_sub_group_size";
40+
3541
/// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
3642
/// intrinsic. Returns false otherwise.
3743
static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
@@ -70,11 +76,18 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
7076

7177
/// Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM
7278
/// dialect attributes.
73-
static ArrayRef<unsigned> getSupportedMetadataImpl() {
79+
static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
7480
static const SmallVector<unsigned> convertibleMetadata = {
75-
llvm::LLVMContext::MD_prof, llvm::LLVMContext::MD_tbaa,
76-
llvm::LLVMContext::MD_access_group, llvm::LLVMContext::MD_loop,
77-
llvm::LLVMContext::MD_noalias, llvm::LLVMContext::MD_alias_scope};
81+
llvm::LLVMContext::MD_prof,
82+
llvm::LLVMContext::MD_tbaa,
83+
llvm::LLVMContext::MD_access_group,
84+
llvm::LLVMContext::MD_loop,
85+
llvm::LLVMContext::MD_noalias,
86+
llvm::LLVMContext::MD_alias_scope,
87+
context.getMDKindID(vecTypeHintMDName),
88+
context.getMDKindID(workGroupSizeHintMDName),
89+
context.getMDKindID(reqdWorkGroupSizeMDName),
90+
context.getMDKindID(intelReqdSubGroupSizeMDName)};
7891
return convertibleMetadata;
7992
}
8093

@@ -226,6 +239,128 @@ static LogicalResult setNoaliasScopesAttr(const llvm::MDNode *node,
226239
return success();
227240
}
228241

242+
/// Extracts an integer from the provided metadata `md` if possible. Returns
243+
/// nullopt otherwise.
244+
static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
245+
auto *constant = dyn_cast_if_present<llvm::ConstantAsMetadata>(md);
246+
if (!constant)
247+
return {};
248+
249+
auto *intConstant = dyn_cast<llvm::ConstantInt>(constant->getValue());
250+
if (!intConstant)
251+
return {};
252+
253+
return intConstant->getValue().getSExtValue();
254+
}
255+
256+
/// Converts the provided metadata node `node` to an LLVM dialect
257+
/// VecTypeHintAttr if possible.
258+
static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *node,
259+
ModuleImport &moduleImport) {
260+
if (!node || node->getNumOperands() != 2)
261+
return {};
262+
263+
auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(node->getOperand(0).get());
264+
if (!hintMD)
265+
return {};
266+
TypeAttr hint = TypeAttr::get(moduleImport.convertType(hintMD->getType()));
267+
268+
std::optional<int32_t> optIsSigned =
269+
parseIntegerMD(node->getOperand(1).get());
270+
if (!optIsSigned)
271+
return {};
272+
bool isSigned = *optIsSigned != 0;
273+
274+
return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
275+
}
276+
277+
/// Converts the provided metadata node `node` to an MLIR DenseI32ArrayAttr if
278+
/// possible.
279+
static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
280+
llvm::MDNode *node) {
281+
if (!node)
282+
return {};
283+
SmallVector<int32_t> vals;
284+
for (const llvm::MDOperand &op : node->operands()) {
285+
std::optional<int32_t> mdValue = parseIntegerMD(op.get());
286+
if (!mdValue)
287+
return {};
288+
vals.push_back(*mdValue);
289+
}
290+
return builder.getDenseI32ArrayAttr(vals);
291+
}
292+
293+
/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
294+
static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *node) {
295+
if (!node || node->getNumOperands() != 1)
296+
return {};
297+
std::optional<int32_t> val = parseIntegerMD(node->getOperand(0));
298+
if (!val)
299+
return {};
300+
return builder.getI32IntegerAttr(*val);
301+
}
302+
303+
static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
304+
Operation *op,
305+
LLVM::ModuleImport &moduleImport) {
306+
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
307+
if (!funcOp)
308+
return failure();
309+
310+
VecTypeHintAttr attr = convertVecTypeHint(builder, node, moduleImport);
311+
if (!attr)
312+
return failure();
313+
314+
funcOp.setVecTypeHintAttr(attr);
315+
return success();
316+
}
317+
318+
static LogicalResult
319+
setWorkGroupSizeHintAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
320+
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
321+
if (!funcOp)
322+
return failure();
323+
324+
DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
325+
if (!attr)
326+
return failure();
327+
328+
funcOp.setWorkGroupSizeHintAttr(attr);
329+
return success();
330+
}
331+
332+
static LogicalResult
333+
setReqdWorkGroupSizeAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
334+
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
335+
if (!funcOp)
336+
return failure();
337+
338+
DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
339+
if (!attr)
340+
return failure();
341+
342+
funcOp.setReqdWorkGroupSizeAttr(attr);
343+
return success();
344+
}
345+
346+
/// Converts the given intel required subgroup size metadata node to an MLIR
347+
/// attribute and attaches it to the imported operation if the translation
348+
/// succeeds. Returns failure otherwise.
349+
static LogicalResult setIntelReqdSubGroupSizeAttr(Builder &builder,
350+
llvm::MDNode *node,
351+
Operation *op) {
352+
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
353+
if (!funcOp)
354+
return failure();
355+
356+
IntegerAttr attr = convertIntegerMD(builder, node);
357+
if (!attr)
358+
return failure();
359+
360+
funcOp.setIntelReqdSubGroupSizeAttr(attr);
361+
return success();
362+
}
363+
229364
namespace {
230365

231366
/// Implementation of the dialect interface that converts operations belonging
@@ -261,6 +396,16 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
261396
if (kind == llvm::LLVMContext::MD_noalias)
262397
return setNoaliasScopesAttr(node, op, moduleImport);
263398

399+
llvm::LLVMContext &context = node->getContext();
400+
if (kind == context.getMDKindID(vecTypeHintMDName))
401+
return setVecTypeHintAttr(builder, node, op, moduleImport);
402+
if (kind == context.getMDKindID(workGroupSizeHintMDName))
403+
return setWorkGroupSizeHintAttr(builder, node, op);
404+
if (kind == context.getMDKindID(reqdWorkGroupSizeMDName))
405+
return setReqdWorkGroupSizeAttr(builder, node, op);
406+
if (kind == context.getMDKindID(intelReqdSubGroupSizeMDName))
407+
return setIntelReqdSubGroupSizeAttr(builder, node, op);
408+
264409
// A handler for a supported metadata kind is missing.
265410
llvm_unreachable("unknown metadata type");
266411
}
@@ -273,8 +418,9 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
273418

274419
/// Returns the list of LLVM IR metadata kinds that are convertible to MLIR
275420
/// LLVM dialect attributes.
276-
ArrayRef<unsigned> getSupportedMetadata() const final {
277-
return getSupportedMetadataImpl();
421+
ArrayRef<unsigned>
422+
getSupportedMetadata(llvm::LLVMContext &context) const final {
423+
return getSupportedMetadataImpl(context);
278424
}
279425
};
280426
} // namespace

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -500,10 +500,10 @@ LogicalResult ModuleImport::convertLinkerOptionsMetadata() {
500500
if (named.getName() != "llvm.linker.options")
501501
continue;
502502
// llvm.linker.options operands are lists of strings.
503-
for (const llvm::MDNode *md : named.operands()) {
503+
for (const llvm::MDNode *node : named.operands()) {
504504
SmallVector<StringRef> options;
505-
options.reserve(md->getNumOperands());
506-
for (const llvm::MDOperand &option : md->operands())
505+
options.reserve(node->getNumOperands());
506+
for (const llvm::MDOperand &option : node->operands())
507507
options.push_back(cast<llvm::MDString>(option)->getString());
508508
builder.create<LLVM::LinkerOptionsOp>(mlirModule.getLoc(),
509509
builder.getStrArrayAttr(options));

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,41 @@ static LogicalResult checkedAddLLVMFnAttribute(Location loc,
12471247
return success();
12481248
}
12491249

1250+
/// Return a representation of `value` as metadata.
1251+
static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
1252+
const llvm::APInt &value) {
1253+
llvm::Constant *constant = llvm::ConstantInt::get(context, value);
1254+
return llvm::ConstantAsMetadata::get(constant);
1255+
}
1256+
1257+
/// Return a representation of `value` as an MDNode.
1258+
static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
1259+
const llvm::APInt &value) {
1260+
return llvm::MDNode::get(context, convertIntegerToMetadata(context, value));
1261+
}
1262+
1263+
/// Return an MDNode encoding `vec_type_hint` metadata.
1264+
static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
1265+
llvm::Type *type,
1266+
bool isSigned) {
1267+
llvm::Metadata *typeMD =
1268+
llvm::ConstantAsMetadata::get(llvm::UndefValue::get(type));
1269+
llvm::Metadata *isSignedMD =
1270+
convertIntegerToMetadata(context, llvm::APInt(32, isSigned ? 1 : 0));
1271+
return llvm::MDNode::get(context, {typeMD, isSignedMD});
1272+
}
1273+
1274+
/// Return an MDNode with a tuple given by the values in `values`.
1275+
static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
1276+
ArrayRef<int32_t> values) {
1277+
SmallVector<llvm::Metadata *> mdValues;
1278+
llvm::transform(
1279+
values, std::back_inserter(mdValues), [&context](int32_t value) {
1280+
return convertIntegerToMetadata(context, llvm::APInt(32, value));
1281+
});
1282+
return llvm::MDNode::get(context, mdValues);
1283+
}
1284+
12501285
/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
12511286
/// Reports error to `loc` if any and returns immediately. Expects `attributes`
12521287
/// to be an array attribute containing either string attributes, treated as
@@ -1448,6 +1483,44 @@ static void convertFunctionAttributes(LLVMFuncOp func,
14481483
convertFunctionMemoryAttributes(func, llvmFunc);
14491484
}
14501485

1486+
/// Converts function attributes from `func` and attaches them to `llvmFunc`.
1487+
static void convertFunctionKernelAttributes(LLVMFuncOp func,
1488+
llvm::Function *llvmFunc,
1489+
ModuleTranslation &translation) {
1490+
llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1491+
1492+
if (VecTypeHintAttr vecTypeHint = func.getVecTypeHintAttr()) {
1493+
Type type = vecTypeHint.getHint().getValue();
1494+
llvm::Type *llvmType = translation.convertType(type);
1495+
bool isSigned = vecTypeHint.getIsSigned();
1496+
llvmFunc->setMetadata(
1497+
func.getVecTypeHintAttrName(),
1498+
convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
1499+
}
1500+
1501+
if (std::optional<ArrayRef<int32_t>> workGroupSizeHint =
1502+
func.getWorkGroupSizeHint()) {
1503+
llvmFunc->setMetadata(
1504+
func.getWorkGroupSizeHintAttrName(),
1505+
convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
1506+
}
1507+
1508+
if (std::optional<ArrayRef<int32_t>> reqdWorkGroupSize =
1509+
func.getReqdWorkGroupSize()) {
1510+
llvmFunc->setMetadata(
1511+
func.getReqdWorkGroupSizeAttrName(),
1512+
convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
1513+
}
1514+
1515+
if (std::optional<uint32_t> intelReqdSubGroupSize =
1516+
func.getIntelReqdSubGroupSize()) {
1517+
llvmFunc->setMetadata(
1518+
func.getIntelReqdSubGroupSizeAttrName(),
1519+
convertIntegerToMDNode(llvmContext,
1520+
llvm::APInt(32, *intelReqdSubGroupSize)));
1521+
}
1522+
}
1523+
14511524
FailureOr<llvm::AttrBuilder>
14521525
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
14531526
DictionaryAttr paramAttrs) {
@@ -1492,6 +1565,9 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
14921565
// Convert function attributes.
14931566
convertFunctionAttributes(function, llvmFunc);
14941567

1568+
// Convert function kernel attributes to metadata.
1569+
convertFunctionKernelAttributes(function, llvmFunc, *this);
1570+
14951571
// Convert function_entry_count attribute to metadata.
14961572
if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
14971573
llvmFunc->setEntryCount(entryCount.value());

0 commit comments

Comments
 (0)