Skip to content

Commit 6168e82

Browse files
authored
[MLIR][LLVM] Add inlining support for loop annotations (#94447)
This commit extends the LLVM dialect's inliner interface support updating loop annotation attributes. This is necessary because the loop annotations can contain debug locations, which are verified by LLVM's verifier. LLVM requires these locations to have the same scope as the function this attribute is contained in.
1 parent c70fa55 commit 6168e82

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,57 @@ static void handleAccessGroups(Operation *call,
511511
accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
512512
}
513513

514+
/// Updates locations inside loop annotations to reflect that they were inlined.
515+
static void
516+
handleLoopAnnotations(Operation *call,
517+
iterator_range<Region::iterator> inlinedBlocks) {
518+
// Attempt to extract a DISubprogram from the callee.
519+
auto func = call->getParentOfType<FunctionOpInterface>();
520+
if (!func)
521+
return;
522+
LocationAttr funcLoc = func->getLoc();
523+
auto fusedLoc = dyn_cast_if_present<FusedLoc>(funcLoc);
524+
if (!fusedLoc)
525+
return;
526+
auto scope =
527+
dyn_cast_if_present<LLVM::DISubprogramAttr>(fusedLoc.getMetadata());
528+
if (!scope)
529+
return;
530+
531+
// Helper to build a new fused location that reflects the inlining of the loop
532+
// annotation.
533+
auto updateLoc = [&](FusedLoc loc) -> FusedLoc {
534+
if (!loc)
535+
return {};
536+
Location callSiteLoc = CallSiteLoc::get(loc, call->getLoc());
537+
return FusedLoc::get(loc.getContext(), callSiteLoc, scope);
538+
};
539+
540+
AttrTypeReplacer replacer;
541+
replacer.addReplacement([&](LLVM::LoopAnnotationAttr loopAnnotation)
542+
-> std::pair<Attribute, WalkResult> {
543+
FusedLoc newStartLoc = updateLoc(loopAnnotation.getStartLoc());
544+
FusedLoc newEndLoc = updateLoc(loopAnnotation.getEndLoc());
545+
if (!newStartLoc && !newEndLoc)
546+
return {loopAnnotation, WalkResult::advance()};
547+
auto newLoopAnnotation = LLVM::LoopAnnotationAttr::get(
548+
loopAnnotation.getContext(), loopAnnotation.getDisableNonforced(),
549+
loopAnnotation.getVectorize(), loopAnnotation.getInterleave(),
550+
loopAnnotation.getUnroll(), loopAnnotation.getUnrollAndJam(),
551+
loopAnnotation.getLicm(), loopAnnotation.getDistribute(),
552+
loopAnnotation.getPipeline(), loopAnnotation.getPeeled(),
553+
loopAnnotation.getUnswitch(), loopAnnotation.getMustProgress(),
554+
loopAnnotation.getIsVectorized(), newStartLoc, newEndLoc,
555+
loopAnnotation.getParallelAccesses());
556+
// Needs to advance, as loop annotations can be nested.
557+
return {newLoopAnnotation, WalkResult::advance()};
558+
});
559+
560+
for (Block &block : inlinedBlocks)
561+
for (Operation &op : block)
562+
replacer.recursivelyReplaceElementsIn(&op);
563+
}
564+
514565
/// If `requestedAlignment` is higher than the alignment specified on `alloca`,
515566
/// realigns `alloca` if this does not exceed the natural stack alignment.
516567
/// Returns the post-alignment of `alloca`, whether it was realigned or not.
@@ -784,6 +835,7 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
784835
handleInlinedAllocas(call, inlinedBlocks);
785836
handleAliasScopes(call, inlinedBlocks);
786837
handleAccessGroups(call, inlinedBlocks);
838+
handleLoopAnnotations(call, inlinedBlocks);
787839
}
788840

789841
// Keeping this (immutable) state on the interface allows us to look up
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: mlir-opt %s -inline -split-input-file | FileCheck %s
2+
3+
#di_file = #llvm.di_file<"file.mlir" in "/">
4+
5+
// CHECK: #[[START_ORIGINAL:.*]] = loc({{.*}}:42
6+
#loc1 = loc("test.mlir":42:4)
7+
// CHECK: #[[END_ORIGINAL:.*]] = loc({{.*}}:52
8+
#loc2 = loc("test.mlir":52:4)
9+
#loc3 = loc("test.mlir":62:4)
10+
// CHECK: #[[CALL_ORIGINAL:.*]] = loc({{.*}}:72
11+
#loc4 = loc("test.mlir":72:4)
12+
13+
#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>, sourceLanguage = DW_LANG_C, file = #di_file, isOptimized = false, emissionKind = None>
14+
// CHECK: #[[CALLEE_DI:.*]] = #llvm.di_subprogram<{{.*}}, name = "callee"
15+
#di_subprogram_callee = #llvm.di_subprogram<compileUnit = #di_compile_unit, scope = #di_file, name = "callee", file = #di_file, subprogramFlags = Definition>
16+
17+
// CHECK: #[[CALLER_DI:.*]] = #llvm.di_subprogram<{{.*}}, name = "caller"
18+
#di_subprogram_caller = #llvm.di_subprogram<compileUnit = #di_compile_unit, scope = #di_file, name = "caller", file = #di_file, subprogramFlags = Definition>
19+
20+
// CHECK: #[[START_FUSED_ORIGINAL:.*]] = loc(fused<#[[CALLEE_DI]]>[#[[START_ORIGINAL]]
21+
#start_loc_fused = loc(fused<#di_subprogram_callee>[#loc1])
22+
// CHECK: #[[END_FUSED_ORIGINAL:.*]] = loc(fused<#[[CALLEE_DI]]>[#[[END_ORIGINAL]]
23+
#end_loc_fused= loc(fused<#di_subprogram_callee>[#loc2])
24+
#caller_loc= loc(fused<#di_subprogram_caller>[#loc3])
25+
// CHECK: #[[CALL_FUSED:.*]] = loc(fused<#[[CALLER_DI]]>[#[[CALL_ORIGINAL]]
26+
#call_loc= loc(fused<#di_subprogram_caller>[#loc4])
27+
28+
#loopMD = #llvm.loop_annotation<
29+
startLoc = #start_loc_fused,
30+
endLoc = #end_loc_fused>
31+
32+
// CHECK: #[[START_CALLSITE_LOC:.*]] = loc(callsite(#[[START_FUSED_ORIGINAL]] at #[[CALL_FUSED]]
33+
// CHECK: #[[END_CALLSITE_LOC:.*]] = loc(callsite(#[[END_FUSED_ORIGINAL]] at #[[CALL_FUSED]]
34+
// CHECK: #[[START_FUSED_LOC:.*]] = loc(fused<#[[CALLER_DI]]>[#[[START_CALLSITE_LOC]]
35+
// CHECK: #[[END_FUSED_LOC:.*]] = loc(fused<#[[CALLER_DI]]>[
36+
// CHECK: #[[LOOP_ANNOT:.*]] = #llvm.loop_annotation<
37+
// CHECK-SAME: startLoc = #[[START_FUSED_LOC]], endLoc = #[[END_FUSED_LOC]]>
38+
39+
llvm.func @cond() -> i1
40+
41+
llvm.func @callee() {
42+
llvm.br ^head
43+
^head:
44+
%c = llvm.call @cond() : () -> i1
45+
llvm.cond_br %c, ^head, ^exit {loop_annotation = #loopMD}
46+
^exit:
47+
llvm.return
48+
}
49+
50+
// CHECK: @loop_annotation
51+
llvm.func @loop_annotation() {
52+
// CHECK: llvm.cond_br
53+
// CHECK-SAME: {loop_annotation = #[[LOOP_ANNOT]]
54+
llvm.call @callee() : () -> () loc(#call_loc)
55+
llvm.return
56+
} loc(#caller_loc)

0 commit comments

Comments
 (0)