@@ -136,7 +136,8 @@ bool isESIMDFunction(const Function &F) {
136
136
// This function makes one or two groups depending on kernel types (SYCL, ESIMD)
137
137
EntryPointGroupVec
138
138
groupEntryPointsByKernelType (const Module &M, bool EmitOnlyKernelsAsEntryPoints,
139
- EntryPointVec *AllowedEntriesVec) {
139
+ EntryPointVec *AllowedEntriesVec,
140
+ EntryPointGroup::Properties BlueprintProps) {
140
141
SmallPtrSet<const Function *, 32 > AllowedEntries;
141
142
142
143
if (AllowedEntriesVec) {
@@ -161,7 +162,8 @@ groupEntryPointsByKernelType(const Module &M, bool EmitOnlyKernelsAsEntryPoints,
161
162
162
163
if (!EntryPointMap.empty ()) {
163
164
for (auto & EPG : EntryPointMap) {
164
- EntryPointGroups.push_back ({ EPG.first , std::move (EPG.second ) });
165
+ EntryPointGroups.emplace_back (
166
+ EntryPointGroup{EPG.first , std::move (EPG.second ), BlueprintProps});
165
167
EntryPointGroup& G = EntryPointGroups.back ();
166
168
167
169
if (G.GroupId == ESIMD_SCOPE_NAME) {
@@ -185,9 +187,10 @@ groupEntryPointsByKernelType(const Module &M, bool EmitOnlyKernelsAsEntryPoints,
185
187
// which contains pairs of group id and entry points for that group. Each such
186
188
// group along with IR it depends on (globals, functions from its call graph,
187
189
// ...) will constitute a separate module.
188
- EntryPointGroupVec groupEntryPointsByScope (const Module &M,
189
- EntryPointsGroupScope EntryScope,
190
- bool EmitOnlyKernelsAsEntryPoints) {
190
+ EntryPointGroupVec
191
+ groupEntryPointsByScope (const Module &M, EntryPointsGroupScope EntryScope,
192
+ bool EmitOnlyKernelsAsEntryPoints,
193
+ EntryPointGroup::Properties BlueprintProps) {
191
194
EntryPointGroupVec EntryPointGroups{};
192
195
// Use MapVector for deterministic order of traversal (helps tests).
193
196
MapVector<StringRef, EntryPointVec> EntryPointMap;
@@ -227,7 +230,8 @@ EntryPointGroupVec groupEntryPointsByScope(const Module &M,
227
230
if (!EntryPointMap.empty ()) {
228
231
EntryPointGroups.reserve (EntryPointMap.size ());
229
232
for (auto & EPG : EntryPointMap) {
230
- EntryPointGroups.push_back ({ EPG.first , std::move (EPG.second ) });
233
+ EntryPointGroups.emplace_back (
234
+ EntryPointGroup{EPG.first , std::move (EPG.second ), BlueprintProps});
231
235
EntryPointGroup& G = EntryPointGroups.back ();
232
236
G.Props .Scope = EntryScope;
233
237
}
@@ -239,10 +243,9 @@ EntryPointGroupVec groupEntryPointsByScope(const Module &M,
239
243
}
240
244
241
245
template <class EntryPoinGroupFunc >
242
- EntryPointGroupVec
243
- groupEntryPointsByAttribute (const Module &M, StringRef AttrName,
244
- bool EmitOnlyKernelsAsEntryPoints,
245
- EntryPoinGroupFunc F) {
246
+ EntryPointGroupVec groupEntryPointsByAttribute (
247
+ const Module &M, StringRef AttrName, bool EmitOnlyKernelsAsEntryPoints,
248
+ EntryPoinGroupFunc F, EntryPointGroup::Properties BlueprintProps) {
246
249
EntryPointGroupVec EntryPointGroups{};
247
250
std::map<StringRef, EntryPointVec> EntryPointMap;
248
251
@@ -260,7 +263,8 @@ groupEntryPointsByAttribute(const Module &M, StringRef AttrName,
260
263
if (!EntryPointMap.empty ()) {
261
264
EntryPointGroups.reserve (EntryPointMap.size ());
262
265
for (auto & EPG : EntryPointMap) {
263
- EntryPointGroups.push_back ({ EPG.first , std::move (EPG.second ) });
266
+ EntryPointGroups.emplace_back (
267
+ EntryPointGroup{EPG.first , std::move (EPG.second ), BlueprintProps});
264
268
F (EntryPointGroups.back ());
265
269
}
266
270
} else {
@@ -271,11 +275,9 @@ groupEntryPointsByAttribute(const Module &M, StringRef AttrName,
271
275
return EntryPointGroups;
272
276
}
273
277
274
- // Records a use graph between functions in a module. Nodes are functions, edges
275
- // are "uses" relation. One function "uses" another if any of its instructions
276
- // use the other function. Typical use is a call, another example of use is
277
- // storing
278
- class FunctionUseGraph {
278
+ // Represents a call graph between functions in a module. Nodes are functions,
279
+ // edges are "calls" relation.
280
+ class CallGraph {
279
281
public:
280
282
using FunctionSet = SmallPtrSet<const Function*, 16 >;
281
283
@@ -284,11 +286,11 @@ class FunctionUseGraph {
284
286
SmallPtrSet<const Function*, 1 > EmptySet;
285
287
286
288
public:
287
- FunctionUseGraph (const Module& M, bool AllUsesAreGraphEdges = false ) {
289
+ CallGraph (const Module &M ) {
288
290
for (const auto & F : M) {
289
291
for (const Value* U : F.users ()) {
290
- if (const Instruction* I = dyn_cast<Instruction >(U)) {
291
- if (AllUsesAreGraphEdges || dyn_cast< const CallInst>(I) ) {
292
+ if (const auto * I = dyn_cast<CallInst >(U)) {
293
+ if (I-> getCalledFunction () == &F ) {
292
294
const Function* F1 = I->getFunction ();
293
295
Graph[F1].insert (&F);
294
296
}
@@ -305,7 +307,7 @@ class FunctionUseGraph {
305
307
306
308
void collectFunctionsToExtract (SetVector<const GlobalValue *> &GVs,
307
309
const EntryPointGroup &ModuleEntryPoints,
308
- const FunctionUseGraph &Deps) {
310
+ const CallGraph &Deps) {
309
311
for (const auto *F : ModuleEntryPoints.Functions )
310
312
GVs.insert (F);
311
313
@@ -336,9 +338,10 @@ void collectGlobalVarsToExtract(SetVector<const GlobalValue *> &GVs,
336
338
GVs.insert (&G);
337
339
}
338
340
339
- ModuleDesc extractSubModule (const Module &M ,
341
+ ModuleDesc extractSubModule (const ModuleDesc &MD ,
340
342
const SetVector<const GlobalValue *> GVs,
341
343
EntryPointGroup &&ModuleEntryPoints) {
344
+ const Module &M = MD.getModule ();
342
345
// For each group of entry points collect all dependencies.
343
346
ValueToValueMapTy VMap;
344
347
// Clone definitions only for needed globals. Others will be added as
@@ -353,7 +356,7 @@ ModuleDesc extractSubModule(const Module &M,
353
356
EPs.cbegin (), EPs.cend (), std::inserter (NewEPs, NewEPs.end ()),
354
357
[&VMap](const Function *F) { return cast<Function>(VMap[F]); });
355
358
ModuleEntryPoints.Functions = std::move (NewEPs);
356
- return ModuleDesc{std::move (SubM), std::move (ModuleEntryPoints)};
359
+ return ModuleDesc{std::move (SubM), std::move (ModuleEntryPoints), MD. Props };
357
360
}
358
361
359
362
// TODO: try to move including all passes (cleanup, spec consts, compile time
@@ -371,14 +374,13 @@ void cleanupSplitModule(Module &SplitM) {
371
374
372
375
// The function produces a copy of input LLVM IR module M with only those entry
373
376
// points that are specified in ModuleEntryPoints vector.
374
- ModuleDesc extractCallGraph (const Module &M,
375
- EntryPointGroup &&ModuleEntryPoints,
376
- bool AllUsesAreCallGraphEdges = false ) {
377
+ ModuleDesc extractCallGraph (const ModuleDesc &MD,
378
+ EntryPointGroup &&ModuleEntryPoints) {
377
379
SetVector<const GlobalValue *> GVs;
378
- collectFunctionsToExtract (GVs, ModuleEntryPoints, FunctionUseGraph{M, AllUsesAreCallGraphEdges });
379
- collectGlobalVarsToExtract (GVs, M );
380
+ collectFunctionsToExtract (GVs, ModuleEntryPoints, CallGraph{MD. getModule () });
381
+ collectGlobalVarsToExtract (GVs, MD. getModule () );
380
382
381
- ModuleDesc SplitM = extractSubModule (M , GVs, std::move (ModuleEntryPoints));
383
+ ModuleDesc SplitM = extractSubModule (MD , GVs, std::move (ModuleEntryPoints));
382
384
cleanupSplitModule (SplitM.getModule ());
383
385
384
386
return SplitM;
@@ -389,19 +391,17 @@ class ModuleCopier : public ModuleSplitterBase {
389
391
using ModuleSplitterBase::ModuleSplitterBase; // to inherit base constructors
390
392
391
393
ModuleDesc nextSplit () override {
392
- return {releaseInputModule (), nextGroup ()};
394
+ return ModuleDesc {releaseInputModule (), nextGroup (), Input. Props };
393
395
}
394
396
};
395
397
396
398
class ModuleSplitter : public ModuleSplitterBase {
397
- bool AllUsesAreCallGraphEdges;
398
-
399
399
public:
400
-
401
- ModuleSplitter (std::unique_ptr<Module> M, EntryPointGroupVec&& GroupVec, bool AllUsesAreCallGraphEdges = false ) : ModuleSplitterBase(std::move(M ), std::move(GroupVec)), AllUsesAreCallGraphEdges(AllUsesAreCallGraphEdges ) {}
400
+ ModuleSplitter (ModuleDesc &&MD, EntryPointGroupVec &&GroupVec)
401
+ : ModuleSplitterBase(std::move(MD ), std::move(GroupVec)) {}
402
402
403
403
ModuleDesc nextSplit () override {
404
- return extractCallGraph (getInputModule () , nextGroup (), AllUsesAreCallGraphEdges );
404
+ return extractCallGraph (Input , nextGroup ());
405
405
}
406
406
};
407
407
@@ -411,39 +411,40 @@ namespace llvm {
411
411
namespace module_split {
412
412
413
413
std::unique_ptr<ModuleSplitterBase>
414
- getSplitterByKernelType (std::unique_ptr<Module> M,
415
- bool EmitOnlyKernelsAsEntryPoints,
414
+ getSplitterByKernelType (ModuleDesc &&MD, bool EmitOnlyKernelsAsEntryPoints,
416
415
EntryPointVec *AllowedEntries) {
417
416
EntryPointGroupVec Groups = groupEntryPointsByKernelType (
418
- *M, EmitOnlyKernelsAsEntryPoints, AllowedEntries);
417
+ MD.getModule (), EmitOnlyKernelsAsEntryPoints, AllowedEntries,
418
+ MD.getEntryPointGroup ().Props );
419
419
bool DoSplit = (Groups.size () > 1 );
420
420
421
421
if (DoSplit)
422
- return std::make_unique<ModuleSplitter>(std::move (M ), std::move (Groups));
422
+ return std::make_unique<ModuleSplitter>(std::move (MD ), std::move (Groups));
423
423
else
424
- return std::make_unique<ModuleCopier>(std::move (M ), std::move (Groups));
424
+ return std::make_unique<ModuleCopier>(std::move (MD ), std::move (Groups));
425
425
}
426
426
427
427
std::unique_ptr<ModuleSplitterBase>
428
- getSplitterByMode (std::unique_ptr<Module> M , IRSplitMode Mode,
428
+ getSplitterByMode (ModuleDesc &&MD , IRSplitMode Mode,
429
429
bool AutoSplitIsGlobalScope,
430
430
bool EmitOnlyKernelsAsEntryPoints) {
431
431
EntryPointsGroupScope Scope =
432
- selectDeviceCodeGroupScope (*M, Mode, AutoSplitIsGlobalScope);
433
- EntryPointGroupVec Groups =
434
- groupEntryPointsByScope (*M, Scope, EmitOnlyKernelsAsEntryPoints);
432
+ selectDeviceCodeGroupScope (MD.getModule (), Mode, AutoSplitIsGlobalScope);
433
+ EntryPointGroupVec Groups = groupEntryPointsByScope (
434
+ MD.getModule (), Scope, EmitOnlyKernelsAsEntryPoints,
435
+ MD.getEntryPointGroup ().Props );
435
436
assert (!Groups.empty () && " At least one group is expected" );
436
437
bool DoSplit = (Mode != SPLIT_NONE &&
437
438
(Groups.size () > 1 || !Groups.cbegin ()->Functions .empty ()));
438
439
439
440
if (DoSplit)
440
- return std::make_unique<ModuleSplitter>(std::move (M ), std::move (Groups));
441
+ return std::make_unique<ModuleSplitter>(std::move (MD ), std::move (Groups));
441
442
else
442
- return std::make_unique<ModuleCopier>(std::move (M ), std::move (Groups));
443
+ return std::make_unique<ModuleCopier>(std::move (MD ), std::move (Groups));
443
444
}
444
445
445
446
void ModuleSplitterBase::verifyNoCrossModuleDeviceGlobalUsage () {
446
- const Module &M = *InputModule ;
447
+ const Module &M = getInputModule () ;
447
448
// Early exit if there is only one group
448
449
if (Groups.size () < 2 )
449
450
return ;
@@ -669,21 +670,22 @@ void EntryPointGroup::rebuild(const Module &M) {
669
670
}
670
671
671
672
std::unique_ptr<ModuleSplitterBase>
672
- getESIMDDoubleGRFSplitter (std::unique_ptr<Module> M,
673
- bool EmitOnlyKernelsAsEntryPoints) {
673
+ getESIMDDoubleGRFSplitter (ModuleDesc &&MD, bool EmitOnlyKernelsAsEntryPoints) {
674
674
EntryPointGroupVec Groups = groupEntryPointsByAttribute (
675
- *M, ATTR_DOUBLE_GRF, EmitOnlyKernelsAsEntryPoints, [](EntryPointGroup& G) {
676
- if (G.GroupId == ATTR_DOUBLE_GRF) {
677
- G.Props .UsesDoubleGRF = true ;
678
- }
679
- });
675
+ MD.getModule (), ATTR_DOUBLE_GRF, EmitOnlyKernelsAsEntryPoints,
676
+ [](EntryPointGroup &G) {
677
+ if (G.GroupId == ATTR_DOUBLE_GRF) {
678
+ G.Props .UsesDoubleGRF = true ;
679
+ }
680
+ },
681
+ MD.getEntryPointGroup ().Props );
680
682
assert (!Groups.empty () && " At least one group is expected" );
681
683
assert (Groups.size () <= 2 && " At most 2 groups are expected" );
682
684
683
685
if (Groups.size () > 1 )
684
- return std::make_unique<ModuleSplitter>(std::move (M ), std::move (Groups));
686
+ return std::make_unique<ModuleSplitter>(std::move (MD ), std::move (Groups));
685
687
else
686
- return std::make_unique<ModuleCopier>(std::move (M ), std::move (Groups));
688
+ return std::make_unique<ModuleCopier>(std::move (MD ), std::move (Groups));
687
689
}
688
690
689
691
} // namespace module_split
0 commit comments