Skip to content

Commit 28934fe

Browse files
[MLIR][LLVM] Add ProfileSummary module flag support (llvm#138070)
Add one more of these module flags. Unlike "CG Profile", LLVM proper does not verify the content of the metadata, but returns a nullptr in case it's ill-formed (it's up to the user to take action). This prompted me to implement warning checks, preventing the importer to consume broken data.
1 parent 75532b2 commit 28934fe

File tree

11 files changed

+678
-16
lines changed

11 files changed

+678
-16
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,54 @@ def ModuleFlagCGProfileEntryAttr
13781378
let assemblyFormat = "`<` struct(params) `>`";
13791379
}
13801380

1381+
def ModuleFlagProfileSummaryDetailedAttr
1382+
: LLVM_Attr<"ModuleFlagProfileSummaryDetailed", "profile_summary_detailed"> {
1383+
let summary = "ProfileSummary detailed information";
1384+
let description = [{
1385+
Contains detailed information pertinent to "ProfileSummary" attribute.
1386+
A `#llvm.profile_summary` may contain several of it.
1387+
```mlir
1388+
llvm.module_flags [ ...
1389+
detailed_summary =
1390+
<cut_off = 10000, min_count = 86427, num_counts = 1>,
1391+
<cut_off = 100000, min_count = 86427, num_counts = 1>
1392+
```
1393+
}];
1394+
let parameters = (ins "uint32_t":$cut_off,
1395+
"uint64_t":$min_count,
1396+
"uint32_t":$num_counts);
1397+
let assemblyFormat = "`<` struct(params) `>`";
1398+
}
1399+
1400+
def ModuleFlagProfileSummaryAttr
1401+
: LLVM_Attr<"ModuleFlagProfileSummary", "profile_summary"> {
1402+
let summary = "ProfileSummary module flag";
1403+
let description = [{
1404+
Describes ProfileSummary gathered data in a module. Example:
1405+
```mlir
1406+
llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
1407+
#llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
1408+
max_internal_count = 86427, max_function_count = 4691,
1409+
num_counts = 3712, num_functions = 796,
1410+
is_partial_profile = 0,
1411+
partial_profile_ratio = 0.000000e+00 : f64,
1412+
detailed_summary =
1413+
<cut_off = 10000, min_count = 86427, num_counts = 1>,
1414+
<cut_off = 100000, min_count = 86427, num_counts = 1>
1415+
>>]
1416+
```
1417+
}];
1418+
let parameters = (ins "ProfileSummaryFormatKind":$format,
1419+
"uint64_t":$total_count, "uint64_t":$max_count,
1420+
"uint64_t":$max_internal_count, "uint64_t":$max_function_count,
1421+
"uint64_t":$num_counts, "uint64_t":$num_functions,
1422+
OptionalParameter<"std::optional<uint64_t>">:$is_partial_profile,
1423+
OptionalParameter<"FloatAttr">:$partial_profile_ratio,
1424+
ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary);
1425+
1426+
let assemblyFormat = "`<` struct(params) `>`";
1427+
}
1428+
13811429
//===----------------------------------------------------------------------===//
13821430
// LLVM_DependentLibrariesAttr
13831431
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def LLVM_Dialect : Dialect {
9292
static StringRef getModuleFlagKeyCGProfileName() {
9393
return "CG Profile";
9494
}
95+
static StringRef getModuleFlagKeyProfileSummaryName() {
96+
return "ProfileSummary";
97+
}
9598

9699
/// Returns `true` if the given type is compatible with the LLVM dialect.
97100
static bool isCompatibleType(Type);

mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ def FPExceptionBehaviorAttr : LLVM_EnumAttr<
826826
}
827827

828828
//===----------------------------------------------------------------------===//
829-
// Module Flag Behavior
829+
// Module Flags
830830
//===----------------------------------------------------------------------===//
831831

832832
// These values must match llvm::Module::ModFlagBehavior ones.
@@ -858,6 +858,21 @@ def ModFlagBehaviorAttr : LLVM_EnumAttr<
858858
let cppNamespace = "::mlir::LLVM";
859859
}
860860

861+
def LLVM_ProfileSummaryFormatSampleProfile : I64EnumAttrCase<"SampleProfile",
862+
0>;
863+
def LLVM_ProfileSummaryFormatInstrProf : I64EnumAttrCase<"InstrProf", 1>;
864+
def LLVM_ProfileSummaryFormatCSInstrProf : I64EnumAttrCase<"CSInstrProf", 2>;
865+
866+
def LLVM_ProfileSummaryFormatKind : I64EnumAttr<
867+
"ProfileSummaryFormatKind",
868+
"LLVM ProfileSummary format kinds", [
869+
LLVM_ProfileSummaryFormatSampleProfile,
870+
LLVM_ProfileSummaryFormatInstrProf,
871+
LLVM_ProfileSummaryFormatCSInstrProf,
872+
]> {
873+
let cppNamespace = "::mlir::LLVM";
874+
}
875+
861876
//===----------------------------------------------------------------------===//
862877
// UWTableKind
863878
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,13 @@ ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
390390
return success();
391391
}
392392

393+
if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) {
394+
if (!isa<ModuleFlagProfileSummaryAttr>(value))
395+
return emitError() << "'ProfileSummary' key expects a "
396+
"'#llvm.profile_summary' attribute";
397+
return success();
398+
}
399+
393400
if (isa<IntegerAttr, StringAttr>(value))
394401
return success();
395402

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,69 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
303303
return nullptr;
304304
}
305305

306+
static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
307+
StringRef key, ModuleFlagProfileSummaryAttr summaryAttr,
308+
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) {
309+
llvm::LLVMContext &context = builder.getContext();
310+
llvm::MDBuilder mdb(context);
311+
SmallVector<llvm::Metadata *> summaryNodes;
312+
313+
auto getIntTuple = [&](StringRef key, uint64_t val) -> llvm::MDTuple * {
314+
SmallVector<llvm::Metadata *> tupleNodes{
315+
mdb.createString(key), mdb.createConstant(llvm::ConstantInt::get(
316+
llvm::Type::getInt64Ty(context), val))};
317+
return llvm::MDTuple::get(context, tupleNodes);
318+
};
319+
320+
SmallVector<llvm::Metadata *> fmtNode{
321+
mdb.createString("ProfileFormat"),
322+
mdb.createString(
323+
stringifyProfileSummaryFormatKind(summaryAttr.getFormat()))};
324+
325+
SmallVector<llvm::Metadata *> vals = {
326+
llvm::MDTuple::get(context, fmtNode),
327+
getIntTuple("TotalCount", summaryAttr.getTotalCount()),
328+
getIntTuple("MaxCount", summaryAttr.getMaxCount()),
329+
getIntTuple("MaxInternalCount", summaryAttr.getMaxInternalCount()),
330+
getIntTuple("MaxFunctionCount", summaryAttr.getMaxFunctionCount()),
331+
getIntTuple("NumCounts", summaryAttr.getNumCounts()),
332+
getIntTuple("NumFunctions", summaryAttr.getNumFunctions()),
333+
};
334+
335+
if (summaryAttr.getIsPartialProfile())
336+
vals.push_back(
337+
getIntTuple("IsPartialProfile", *summaryAttr.getIsPartialProfile()));
338+
339+
if (summaryAttr.getPartialProfileRatio()) {
340+
SmallVector<llvm::Metadata *> tupleNodes{
341+
mdb.createString("PartialProfileRatio"),
342+
mdb.createConstant(llvm::ConstantFP::get(
343+
llvm::Type::getDoubleTy(context),
344+
summaryAttr.getPartialProfileRatio().getValue()))};
345+
vals.push_back(llvm::MDTuple::get(context, tupleNodes));
346+
}
347+
348+
SmallVector<llvm::Metadata *> detailedEntries;
349+
llvm::Type *llvmInt64Type = llvm::Type::getInt64Ty(context);
350+
for (ModuleFlagProfileSummaryDetailedAttr detailedEntry :
351+
summaryAttr.getDetailedSummary()) {
352+
SmallVector<llvm::Metadata *> tupleNodes{
353+
mdb.createConstant(
354+
llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getCutOff())),
355+
mdb.createConstant(
356+
llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getMinCount())),
357+
mdb.createConstant(llvm::ConstantInt::get(
358+
llvmInt64Type, detailedEntry.getNumCounts()))};
359+
detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes));
360+
}
361+
SmallVector<llvm::Metadata *> detailedSummary{
362+
mdb.createString("DetailedSummary"),
363+
llvm::MDTuple::get(context, detailedEntries)};
364+
vals.push_back(llvm::MDTuple::get(context, detailedSummary));
365+
366+
return llvm::MDNode::get(context, vals);
367+
}
368+
306369
static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
307370
LLVM::ModuleTranslation &moduleTranslation) {
308371
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
@@ -323,6 +386,11 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
323386
arrayAttr, builder,
324387
moduleTranslation);
325388
})
389+
.Case([&](ModuleFlagProfileSummaryAttr summaryAttr) {
390+
return convertModuleFlagProfileSummaryAttr(
391+
flagAttr.getKey().getValue(), summaryAttr, builder,
392+
moduleTranslation);
393+
})
326394
.Default([](auto) { return nullptr; });
327395

328396
assert(valueMetadata && "expected valid metadata");

0 commit comments

Comments
 (0)