Skip to content

Commit 84de9d6

Browse files
authored
[SYCL] Improve sycl-post-link performance with -split=kernel (#6689)
Right now we compute a new call graph in every call to extractCallGraph. extractCallGraph is called every time we do a module split, so for -split=kernel that is once per kernel. For modules with many kernels, this can take a very long time. We only need to compute the call graph once, because the input IR doesn't seem to change between splits. This improves the performance of sycl-post-link from ~45min to ~7min for an example with 13k kernels. Signed-off-by: Sarnie, Nick <[email protected]>
1 parent 32a2777 commit 84de9d6

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

llvm/tools/sycl-post-link/ModuleSplitter.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,10 @@ ModuleDesc extractSubModule(const ModuleDesc &MD,
391391
// The function produces a copy of input LLVM IR module M with only those entry
392392
// points that are specified in ModuleEntryPoints vector.
393393
ModuleDesc extractCallGraph(const ModuleDesc &MD,
394-
EntryPointGroup &&ModuleEntryPoints) {
394+
EntryPointGroup &&ModuleEntryPoints,
395+
const CallGraph &CG) {
395396
SetVector<const GlobalValue *> GVs;
396-
collectFunctionsToExtract(GVs, ModuleEntryPoints, CallGraph{MD.getModule()});
397+
collectFunctionsToExtract(GVs, ModuleEntryPoints, CG);
397398
collectGlobalVarsToExtract(GVs, MD.getModule());
398399

399400
ModuleDesc SplitM = extractSubModule(MD, GVs, std::move(ModuleEntryPoints));
@@ -414,11 +415,15 @@ class ModuleCopier : public ModuleSplitterBase {
414415
class ModuleSplitter : public ModuleSplitterBase {
415416
public:
416417
ModuleSplitter(ModuleDesc &&MD, EntryPointGroupVec &&GroupVec)
417-
: ModuleSplitterBase(std::move(MD), std::move(GroupVec)) {}
418+
: ModuleSplitterBase(std::move(MD), std::move(GroupVec)),
419+
CG(Input.getModule()) {}
418420

419421
ModuleDesc nextSplit() override {
420-
return extractCallGraph(Input, nextGroup());
422+
return extractCallGraph(Input, nextGroup(), CG);
421423
}
424+
425+
private:
426+
CallGraph CG;
422427
};
423428

424429
} // namespace

0 commit comments

Comments
 (0)