-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][LLVM] Attach kernel metadata representation to llvm.func
#101314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add optional attributes to `llvm.func` representing LLVM so-called "kernel" metadata: - [`vec_type_hint`](https://clang.llvm.org/docs/AttributeReference.html#vec-type-hint) - [`work_group_size_hint`](https://clang.llvm.org/docs/AttributeReference.html#work-group-size-hint) - [`reqd_work_group_size`](https://clang.llvm.org/docs/AttributeReference.html#reqd-work-group-size) - [`intel_reqd_sub_group_size`](https://clang.llvm.org/docs/AttributeReference.html#intel-reqd-sub-group-size). Signed-off-by: Victor Perez <[email protected]>
@llvm/pr-subscribers-mlir-llvm Author: Victor Perez (victor-eds) ChangesAdd optional attributes to Full diff: https://github.com/llvm/llvm-project/pull/101314.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 695a962bcab9b..0ecd2dcacffc1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1071,6 +1071,17 @@ def LLVM_UndefAttr : LLVM_Attr<"Undef", "undef">;
/// Folded into from LLVM::PoisonOp.
def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
+//===----------------------------------------------------------------------===//
+// VecTypeHintAttr
+//===----------------------------------------------------------------------===//
+
+/// Represents "vec_type_hint" values
+def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
+ let parameters = (ins "TypeAttr":$hint,
+ DefaultValuedParameter<"bool", "false">:$is_signed);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
//===----------------------------------------------------------------------===//
// ZeroAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 260d42185b57f..fde42682d807a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1456,7 +1456,12 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<UnitAttr>:$always_inline,
OptionalAttr<UnitAttr>:$no_unwind,
OptionalAttr<UnitAttr>:$will_return,
- OptionalAttr<UnitAttr>:$optimize_none
+ OptionalAttr<UnitAttr>:$optimize_none,
+ // Kernel metadata
+ OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
+ OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
+ OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,
+ OptionalAttr<I32Attr>:$intel_reqd_sub_group_size
);
let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8b40b7b2df6c7..cb18dc5193352 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1918,6 +1918,99 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
}
+/// Extract constant integer value from metadata if this is constant. Return
+/// `std::nullopt` otherwise.
+static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
+ if (!md)
+ return {};
+
+ auto *c = dyn_cast<llvm::ConstantAsMetadata>(md);
+ if (!c)
+ return {};
+
+ auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
+ if (!ci)
+ return {};
+
+ return ci->getValue().getSExtValue();
+}
+
+/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
+template <typename ConvertType>
+static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
+ ConvertType convertType) {
+ if (!md || md->getNumOperands() != 2)
+ return {};
+
+ auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
+ if (!hintMD)
+ return {};
+ TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
+
+ std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
+ if (!optIsSigned)
+ return {};
+ bool isSigned = *optIsSigned != 0;
+
+ return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
+}
+
+/// Convert an `MDNode` to an MLIR `DenseI32ArrayAttr` if possible.
+static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
+ llvm::MDNode *md) {
+ if (!md)
+ return {};
+ SmallVector<int32_t> vals;
+ for (const llvm::MDOperand &op : md->operands()) {
+ std::optional<int32_t> mdValue = parseIntegerMD(op.get());
+ if (!mdValue)
+ return {};
+ vals.push_back(*mdValue);
+ }
+ return builder.getDenseI32ArrayAttr(vals);
+}
+
+/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
+static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
+ if (!md || md->getNumOperands() != 1)
+ return {};
+ std::optional<int32_t> val = parseIntegerMD(md->getOperand(0));
+ if (!val)
+ return {};
+ return builder.getI32IntegerAttr(*val);
+}
+
+/// Process metadata found in kernel functions:
+/// - `vec_type_hint`
+/// - `work_group_size_hint`
+/// - `reqd_work_group_size`
+/// - `intel_reqd_sub_group_size`
+template <typename ConvertType>
+static void processKernelMetadata(llvm::Function *func, LLVMFuncOp funcOp,
+ ConvertType convertType) {
+ Builder builder(funcOp);
+
+ if (VecTypeHintAttr attr = convertVecTypeHint(
+ builder, func->getMetadata("vec_type_hint"), convertType)) {
+ funcOp.setVecTypeHintAttr(attr);
+ }
+
+ if (DenseI32ArrayAttr attr = convertDenseI32Array(
+ builder, func->getMetadata("work_group_size_hint"))) {
+ funcOp.setWorkGroupSizeHintAttr(attr);
+ }
+
+ if (DenseI32ArrayAttr attr = convertDenseI32Array(
+ builder, func->getMetadata("reqd_work_group_size"))) {
+ funcOp.setReqdWorkGroupSizeAttr(attr);
+ }
+
+ if (IntegerAttr attr = convertIntegerMD(
+ builder, func->getMetadata("intel_reqd_sub_group_size"))) {
+ funcOp.setIntelReqdSubGroupSizeAttr(attr);
+ }
+}
+
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
clearRegionState();
@@ -1966,6 +2059,10 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
// Handle Function attributes.
processFunctionAttributes(func, funcOp);
+ // Handle Kernel Metadata
+ processKernelMetadata(func, funcOp,
+ [this](llvm::Type *type) { return convertType(type); });
+
// Convert non-debug metadata by using the dialect interface.
SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
func->getAllMetadata(allMetadata);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3016d1846e00f..00ca8bdbf0c23 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1247,6 +1247,41 @@ static LogicalResult checkedAddLLVMFnAttribute(Location loc,
return success();
}
+/// Return a representation of `value` as metadata.
+static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
+ const llvm::APInt &value) {
+ llvm::Constant *constant = llvm::ConstantInt::get(context, value);
+ return llvm::ConstantAsMetadata::get(constant);
+}
+
+/// Return a representation of `value` as an MDNode.
+static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
+ const llvm::APInt &value) {
+ return llvm::MDNode::get(context, convertIntegerToMetadata(context, value));
+}
+
+/// Return an MDNode encoding `vec_type_hint` metadata.
+static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
+ llvm::Type *type,
+ bool isSigned) {
+ llvm::Metadata *typeMD =
+ llvm::ConstantAsMetadata::get(llvm::UndefValue::get(type));
+ llvm::Metadata *isSignedMD =
+ convertIntegerToMetadata(context, llvm::APInt(32, isSigned ? 1 : 0));
+ return llvm::MDNode::get(context, {typeMD, isSignedMD});
+}
+
+/// Return an MDNode with a tuple given by the values in the input integer array
+/// attribute.
+static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
+ ArrayRef<int32_t> values) {
+ llvm::SmallVector<llvm::Metadata *> mds;
+ llvm::transform(values, std::back_inserter(mds), [&context](int32_t value) {
+ return convertIntegerToMetadata(context, llvm::APInt(32, value));
+ });
+ return llvm::MDNode::get(context, mds);
+}
+
/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
/// Reports error to `loc` if any and returns immediately. Expects `attributes`
/// to be an array attribute containing either string attributes, treated as
@@ -1448,6 +1483,42 @@ static void convertFunctionAttributes(LLVMFuncOp func,
convertFunctionMemoryAttributes(func, llvmFunc);
}
+/// Converts function attributes from `func` and attaches them to `llvmFunc`.
+template <typename TypeConverter>
+static void convertFunctionKernelAttributes(LLVMFuncOp func,
+ llvm::Function *llvmFunc,
+ TypeConverter convertType) {
+ llvm::LLVMContext &llvmContext = llvmFunc->getContext();
+
+ if (auto vecTypeHint = func.getVecTypeHint()) {
+ Type type = vecTypeHint->getHint().getValue();
+ llvm::Type *llvmType = convertType(type);
+ bool isSigned = vecTypeHint->getIsSigned();
+ llvmFunc->setMetadata(
+ "vec_type_hint",
+ convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
+ }
+
+ if (auto workGroupSizeHint = func.getWorkGroupSizeHint()) {
+ llvmFunc->setMetadata(
+ "work_group_size_hint",
+ convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
+ }
+
+ if (auto reqdWorkGroupSize = func.getReqdWorkGroupSize()) {
+ llvmFunc->setMetadata(
+ "reqd_work_group_size",
+ convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
+ }
+
+ if (auto intelReqdSubGroupSize = func.getIntelReqdSubGroupSize()) {
+ llvmFunc->setMetadata(
+ "intel_reqd_sub_group_size",
+ convertIntegerToMDNode(llvmContext,
+ llvm::APInt(32, *intelReqdSubGroupSize)));
+ }
+}
+
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
DictionaryAttr paramAttrs) {
@@ -1492,6 +1563,10 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert function attributes.
convertFunctionAttributes(function, llvmFunc);
+ // Convert function kernel attributes to metadata
+ convertFunctionKernelAttributes(
+ function, llvmFunc, [this](Type type) { return convertType(type); });
+
// Convert function_entry_count attribute to metadata.
if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
llvmFunc->setEntryCount(entryCount.value());
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 0e29a548de72f..40b4e49f08a3e 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-op-generic %s | FileCheck %s --check-prefix=GENERIC
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s -mlir-print-debuginfo | mlir-opt -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
+// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-debuginfo %s | mlir-opt -split-input-file -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-LLVM
module {
@@ -432,3 +432,43 @@ module {
// expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
}) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
}
+
+// -----
+
+// CHECK: @vec_type_hint()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32>
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: @vec_type_hint_signed()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: @vec_type_hint_signed_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: @vec_type_hint_float_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: @vec_type_hint_bfloat_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// -----
+
+// CHECK: @work_group_size_hint()
+// CHECK-SAME: work_group_size_hint = array<i32: 128, 128, 128>
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// -----
+
+// CHECK: @reqd_work_group_size_hint()
+// CHECK-SAME: reqd_work_group_size = array<i32: 128, 256, 128>
+llvm.func @reqd_work_group_size_hint() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// -----
+
+// CHECK: @intel_reqd_sub_group_size_hint()
+// CHECK-SAME: intel_reqd_sub_group_size = 32 : i32
+llvm.func @intel_reqd_sub_group_size_hint() attributes {llvm.intel_reqd_sub_group_size = 32 : i32}
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
new file mode 100644
index 0000000000000..963e40348c795
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
@@ -0,0 +1,34 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK: llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+declare !vec_type_hint !0 void @vec_type_hint()
+
+; CHECK: llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+declare !vec_type_hint !1 void @vec_type_hint_signed()
+
+; CHECK: llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+declare !vec_type_hint !2 void @vec_type_hint_signed_vec()
+
+; CHECK: llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+declare !vec_type_hint !3 void @vec_type_hint_float_vec()
+
+; CHECK: llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+declare !vec_type_hint !4 void @vec_type_hint_bfloat_vec()
+
+; CHECK: llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+declare !work_group_size_hint !5 void @work_group_size_hint()
+
+; CHECK: llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+declare !reqd_work_group_size !6 void @reqd_work_group_size()
+
+; CHECK: llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+declare !intel_reqd_sub_group_size !7 void @intel_reqd_sub_group_size()
+
+!0 = !{i32 undef, i32 0}
+!1 = !{i32 undef, i32 1}
+!2 = !{<2 x i32> undef, i32 1}
+!3 = !{<3 x float> undef, i32 0}
+!4 = !{<8 x bfloat> undef, i32 0}
+!5 = !{i32 128, i32 128, i32 128}
+!6 = !{i32 128, i32 256, i32 128}
+!7 = !{i32 32}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index db54d131299c6..8c2efa3d16f6b 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2549,3 +2549,47 @@ llvm.func @mem_effects_call() {
// CHECK-SAME: memory(read, inaccessiblemem: write)
// CHECK: #[[ATTRS_3]]
// CHECK-SAME: memory(readwrite, argmem: read)
+
+// -----
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT:]] void @vec_type_hint()
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED:]] void @vec_type_hint_signed()
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED_VEC:]] void @vec_type_hint_signed_vec()
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_FVEC:]] void @vec_type_hint_float_vec()
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_BFVEC:]] void @vec_type_hint_bfloat_vec()
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// CHECK: ![[#VEC_TYPE_HINT]] = !{i32 undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED]] = !{i32 undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED_VEC]] = !{<2 x i32> undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_FVEC]] = !{<3 x float> undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_BFVEC]] = !{<8 x bfloat> undef, i32 0}
+
+// -----
+
+// CHECK: declare !work_group_size_hint ![[#WORK_GROUP_SIZE_HINT:]] void @work_group_size_hint()
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// CHECK: ![[#WORK_GROUP_SIZE_HINT]] = !{i32 128, i32 128, i32 128}
+
+// -----
+
+// CHECK: declare !reqd_work_group_size ![[#REQD_WORK_GROUP_SIZE:]] void @reqd_work_group_size()
+llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// CHECK: ![[#REQD_WORK_GROUP_SIZE]] = !{i32 128, i32 256, i32 128}
+
+// -----
+
+// CHECK: declare !intel_reqd_sub_group_size ![[#INTEL_REQD_SUB_GROUP_SIZE:]] void @intel_reqd_sub_group_size()
+llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+
+// CHECK: ![[#INTEL_REQD_SUB_GROUP_SIZE]] = !{i32 32}
|
@llvm/pr-subscribers-mlir Author: Victor Perez (victor-eds) ChangesAdd optional attributes to Full diff: https://github.com/llvm/llvm-project/pull/101314.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 695a962bcab9b..0ecd2dcacffc1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1071,6 +1071,17 @@ def LLVM_UndefAttr : LLVM_Attr<"Undef", "undef">;
/// Folded into from LLVM::PoisonOp.
def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
+//===----------------------------------------------------------------------===//
+// VecTypeHintAttr
+//===----------------------------------------------------------------------===//
+
+/// Represents "vec_type_hint" values
+def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
+ let parameters = (ins "TypeAttr":$hint,
+ DefaultValuedParameter<"bool", "false">:$is_signed);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
//===----------------------------------------------------------------------===//
// ZeroAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 260d42185b57f..fde42682d807a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1456,7 +1456,12 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<UnitAttr>:$always_inline,
OptionalAttr<UnitAttr>:$no_unwind,
OptionalAttr<UnitAttr>:$will_return,
- OptionalAttr<UnitAttr>:$optimize_none
+ OptionalAttr<UnitAttr>:$optimize_none,
+ // Kernel metadata
+ OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
+ OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
+ OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,
+ OptionalAttr<I32Attr>:$intel_reqd_sub_group_size
);
let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8b40b7b2df6c7..cb18dc5193352 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1918,6 +1918,99 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
}
+/// Extract constant integer value from metadata if this is constant. Return
+/// `std::nullopt` otherwise.
+static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
+ if (!md)
+ return {};
+
+ auto *c = dyn_cast<llvm::ConstantAsMetadata>(md);
+ if (!c)
+ return {};
+
+ auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
+ if (!ci)
+ return {};
+
+ return ci->getValue().getSExtValue();
+}
+
+/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
+template <typename ConvertType>
+static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
+ ConvertType convertType) {
+ if (!md || md->getNumOperands() != 2)
+ return {};
+
+ auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
+ if (!hintMD)
+ return {};
+ TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
+
+ std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
+ if (!optIsSigned)
+ return {};
+ bool isSigned = *optIsSigned != 0;
+
+ return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
+}
+
+/// Convert an `MDNode` to an MLIR `DenseI32ArrayAttr` if possible.
+static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
+ llvm::MDNode *md) {
+ if (!md)
+ return {};
+ SmallVector<int32_t> vals;
+ for (const llvm::MDOperand &op : md->operands()) {
+ std::optional<int32_t> mdValue = parseIntegerMD(op.get());
+ if (!mdValue)
+ return {};
+ vals.push_back(*mdValue);
+ }
+ return builder.getDenseI32ArrayAttr(vals);
+}
+
+/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
+static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
+ if (!md || md->getNumOperands() != 1)
+ return {};
+ std::optional<int32_t> val = parseIntegerMD(md->getOperand(0));
+ if (!val)
+ return {};
+ return builder.getI32IntegerAttr(*val);
+}
+
+/// Process metadata found in kernel functions:
+/// - `vec_type_hint`
+/// - `work_group_size_hint`
+/// - `reqd_work_group_size`
+/// - `intel_reqd_sub_group_size`
+template <typename ConvertType>
+static void processKernelMetadata(llvm::Function *func, LLVMFuncOp funcOp,
+ ConvertType convertType) {
+ Builder builder(funcOp);
+
+ if (VecTypeHintAttr attr = convertVecTypeHint(
+ builder, func->getMetadata("vec_type_hint"), convertType)) {
+ funcOp.setVecTypeHintAttr(attr);
+ }
+
+ if (DenseI32ArrayAttr attr = convertDenseI32Array(
+ builder, func->getMetadata("work_group_size_hint"))) {
+ funcOp.setWorkGroupSizeHintAttr(attr);
+ }
+
+ if (DenseI32ArrayAttr attr = convertDenseI32Array(
+ builder, func->getMetadata("reqd_work_group_size"))) {
+ funcOp.setReqdWorkGroupSizeAttr(attr);
+ }
+
+ if (IntegerAttr attr = convertIntegerMD(
+ builder, func->getMetadata("intel_reqd_sub_group_size"))) {
+ funcOp.setIntelReqdSubGroupSizeAttr(attr);
+ }
+}
+
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
clearRegionState();
@@ -1966,6 +2059,10 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
// Handle Function attributes.
processFunctionAttributes(func, funcOp);
+ // Handle Kernel Metadata
+ processKernelMetadata(func, funcOp,
+ [this](llvm::Type *type) { return convertType(type); });
+
// Convert non-debug metadata by using the dialect interface.
SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
func->getAllMetadata(allMetadata);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3016d1846e00f..00ca8bdbf0c23 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1247,6 +1247,41 @@ static LogicalResult checkedAddLLVMFnAttribute(Location loc,
return success();
}
+/// Return a representation of `value` as metadata.
+static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
+ const llvm::APInt &value) {
+ llvm::Constant *constant = llvm::ConstantInt::get(context, value);
+ return llvm::ConstantAsMetadata::get(constant);
+}
+
+/// Return a representation of `value` as an MDNode.
+static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
+ const llvm::APInt &value) {
+ return llvm::MDNode::get(context, convertIntegerToMetadata(context, value));
+}
+
+/// Return an MDNode encoding `vec_type_hint` metadata.
+static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
+ llvm::Type *type,
+ bool isSigned) {
+ llvm::Metadata *typeMD =
+ llvm::ConstantAsMetadata::get(llvm::UndefValue::get(type));
+ llvm::Metadata *isSignedMD =
+ convertIntegerToMetadata(context, llvm::APInt(32, isSigned ? 1 : 0));
+ return llvm::MDNode::get(context, {typeMD, isSignedMD});
+}
+
+/// Return an MDNode with a tuple given by the values in the input integer array
+/// attribute.
+static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
+ ArrayRef<int32_t> values) {
+ llvm::SmallVector<llvm::Metadata *> mds;
+ llvm::transform(values, std::back_inserter(mds), [&context](int32_t value) {
+ return convertIntegerToMetadata(context, llvm::APInt(32, value));
+ });
+ return llvm::MDNode::get(context, mds);
+}
+
/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
/// Reports error to `loc` if any and returns immediately. Expects `attributes`
/// to be an array attribute containing either string attributes, treated as
@@ -1448,6 +1483,42 @@ static void convertFunctionAttributes(LLVMFuncOp func,
convertFunctionMemoryAttributes(func, llvmFunc);
}
+/// Converts function attributes from `func` and attaches them to `llvmFunc`.
+template <typename TypeConverter>
+static void convertFunctionKernelAttributes(LLVMFuncOp func,
+ llvm::Function *llvmFunc,
+ TypeConverter convertType) {
+ llvm::LLVMContext &llvmContext = llvmFunc->getContext();
+
+ if (auto vecTypeHint = func.getVecTypeHint()) {
+ Type type = vecTypeHint->getHint().getValue();
+ llvm::Type *llvmType = convertType(type);
+ bool isSigned = vecTypeHint->getIsSigned();
+ llvmFunc->setMetadata(
+ "vec_type_hint",
+ convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
+ }
+
+ if (auto workGroupSizeHint = func.getWorkGroupSizeHint()) {
+ llvmFunc->setMetadata(
+ "work_group_size_hint",
+ convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
+ }
+
+ if (auto reqdWorkGroupSize = func.getReqdWorkGroupSize()) {
+ llvmFunc->setMetadata(
+ "reqd_work_group_size",
+ convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
+ }
+
+ if (auto intelReqdSubGroupSize = func.getIntelReqdSubGroupSize()) {
+ llvmFunc->setMetadata(
+ "intel_reqd_sub_group_size",
+ convertIntegerToMDNode(llvmContext,
+ llvm::APInt(32, *intelReqdSubGroupSize)));
+ }
+}
+
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
DictionaryAttr paramAttrs) {
@@ -1492,6 +1563,10 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert function attributes.
convertFunctionAttributes(function, llvmFunc);
+ // Convert function kernel attributes to metadata
+ convertFunctionKernelAttributes(
+ function, llvmFunc, [this](Type type) { return convertType(type); });
+
// Convert function_entry_count attribute to metadata.
if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
llvmFunc->setEntryCount(entryCount.value());
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 0e29a548de72f..40b4e49f08a3e 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-op-generic %s | FileCheck %s --check-prefix=GENERIC
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s -mlir-print-debuginfo | mlir-opt -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
+// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-debuginfo %s | mlir-opt -split-input-file -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-LLVM
module {
@@ -432,3 +432,43 @@ module {
// expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
}) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
}
+
+// -----
+
+// CHECK: @vec_type_hint()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32>
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: @vec_type_hint_signed()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: @vec_type_hint_signed_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: @vec_type_hint_float_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: @vec_type_hint_bfloat_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// -----
+
+// CHECK: @work_group_size_hint()
+// CHECK-SAME: work_group_size_hint = array<i32: 128, 128, 128>
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// -----
+
+// CHECK: @reqd_work_group_size_hint()
+// CHECK-SAME: reqd_work_group_size = array<i32: 128, 256, 128>
+llvm.func @reqd_work_group_size_hint() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// -----
+
+// CHECK: @intel_reqd_sub_group_size_hint()
+// CHECK-SAME: intel_reqd_sub_group_size = 32 : i32
+llvm.func @intel_reqd_sub_group_size_hint() attributes {llvm.intel_reqd_sub_group_size = 32 : i32}
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
new file mode 100644
index 0000000000000..963e40348c795
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
@@ -0,0 +1,34 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK: llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+declare !vec_type_hint !0 void @vec_type_hint()
+
+; CHECK: llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+declare !vec_type_hint !1 void @vec_type_hint_signed()
+
+; CHECK: llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+declare !vec_type_hint !2 void @vec_type_hint_signed_vec()
+
+; CHECK: llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+declare !vec_type_hint !3 void @vec_type_hint_float_vec()
+
+; CHECK: llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+declare !vec_type_hint !4 void @vec_type_hint_bfloat_vec()
+
+; CHECK: llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+declare !work_group_size_hint !5 void @work_group_size_hint()
+
+; CHECK: llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+declare !reqd_work_group_size !6 void @reqd_work_group_size()
+
+; CHECK: llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+declare !intel_reqd_sub_group_size !7 void @intel_reqd_sub_group_size()
+
+!0 = !{i32 undef, i32 0}
+!1 = !{i32 undef, i32 1}
+!2 = !{<2 x i32> undef, i32 1}
+!3 = !{<3 x float> undef, i32 0}
+!4 = !{<8 x bfloat> undef, i32 0}
+!5 = !{i32 128, i32 128, i32 128}
+!6 = !{i32 128, i32 256, i32 128}
+!7 = !{i32 32}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index db54d131299c6..8c2efa3d16f6b 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2549,3 +2549,47 @@ llvm.func @mem_effects_call() {
// CHECK-SAME: memory(read, inaccessiblemem: write)
// CHECK: #[[ATTRS_3]]
// CHECK-SAME: memory(readwrite, argmem: read)
+
+// -----
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT:]] void @vec_type_hint()
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED:]] void @vec_type_hint_signed()
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED_VEC:]] void @vec_type_hint_signed_vec()
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_FVEC:]] void @vec_type_hint_float_vec()
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_BFVEC:]] void @vec_type_hint_bfloat_vec()
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// CHECK: ![[#VEC_TYPE_HINT]] = !{i32 undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED]] = !{i32 undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED_VEC]] = !{<2 x i32> undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_FVEC]] = !{<3 x float> undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_BFVEC]] = !{<8 x bfloat> undef, i32 0}
+
+// -----
+
+// CHECK: declare !work_group_size_hint ![[#WORK_GROUP_SIZE_HINT:]] void @work_group_size_hint()
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// CHECK: ![[#WORK_GROUP_SIZE_HINT]] = !{i32 128, i32 128, i32 128}
+
+// -----
+
+// CHECK: declare !reqd_work_group_size ![[#REQD_WORK_GROUP_SIZE:]] void @reqd_work_group_size()
+llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// CHECK: ![[#REQD_WORK_GROUP_SIZE]] = !{i32 128, i32 256, i32 128}
+
+// -----
+
+// CHECK: declare !intel_reqd_sub_group_size ![[#INTEL_REQD_SUB_GROUP_SIZE:]] void @intel_reqd_sub_group_size()
+llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+
+// CHECK: ![[#INTEL_REQD_SUB_GROUP_SIZE]] = !{i32 32}
|
We cannot handle this metadata via the |
This can be added to the list of supported metadata by using the context's |
I think this may only require changes to the interface signatures. I can give it a try and go with that solution if it's less invasive. Thanks for the suggestion! |
@Dinistro done! Now |
Signed-off-by: Victor Perez <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for using / extending the dialect interface.
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo last nits.
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Victor Perez <[email protected]>
@gysit thanks for the comments! Looking way better now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks pretty good, just a couple questions
Add optional attributes to
llvm.func
representing LLVM so-called "kernel" metadata:vec_type_hint
work_group_size_hint
reqd_work_group_size
intel_reqd_sub_group_size
.