8
8
9
9
#include " OffloadWrapper.h"
10
10
#include " llvm/ADT/ArrayRef.h"
11
+ #include " llvm/Frontend/Offloading/Utility.h"
11
12
#include " llvm/IR/Constants.h"
12
13
#include " llvm/IR/GlobalVariable.h"
13
14
#include " llvm/IR/IRBuilder.h"
@@ -39,36 +40,7 @@ enum OffloadEntryKindFlag : uint32_t {
39
40
};
40
41
41
42
IntegerType *getSizeTTy (Module &M) {
42
- LLVMContext &C = M.getContext ();
43
- switch (M.getDataLayout ().getPointerTypeSize (PointerType::getUnqual (C))) {
44
- case 4u :
45
- return Type::getInt32Ty (C);
46
- case 8u :
47
- return Type::getInt64Ty (C);
48
- }
49
- llvm_unreachable (" unsupported pointer type size" );
50
- }
51
-
52
- // struct __tgt_offload_entry {
53
- // void *addr;
54
- // char *name;
55
- // size_t size;
56
- // int32_t flags;
57
- // int32_t reserved;
58
- // };
59
- StructType *getEntryTy (Module &M) {
60
- LLVMContext &C = M.getContext ();
61
- StructType *EntryTy = StructType::getTypeByName (C, " __tgt_offload_entry" );
62
- if (!EntryTy)
63
- EntryTy =
64
- StructType::create (" __tgt_offload_entry" , PointerType::getUnqual (C),
65
- PointerType::getUnqual (C), getSizeTTy (M),
66
- Type::getInt32Ty (C), Type::getInt32Ty (C));
67
- return EntryTy;
68
- }
69
-
70
- PointerType *getEntryPtrTy (Module &M) {
71
- return PointerType::getUnqual (getEntryTy (M));
43
+ return M.getDataLayout ().getIntPtrType (M.getContext ());
72
44
}
73
45
74
46
// struct __tgt_device_image {
@@ -81,9 +53,10 @@ StructType *getDeviceImageTy(Module &M) {
81
53
LLVMContext &C = M.getContext ();
82
54
StructType *ImageTy = StructType::getTypeByName (C, " __tgt_device_image" );
83
55
if (!ImageTy)
84
- ImageTy = StructType::create (
85
- " __tgt_device_image" , PointerType::getUnqual (C),
86
- PointerType::getUnqual (C), getEntryPtrTy (M), getEntryPtrTy (M));
56
+ ImageTy =
57
+ StructType::create (" __tgt_device_image" , PointerType::getUnqual (C),
58
+ PointerType::getUnqual (C), PointerType::getUnqual (C),
59
+ PointerType::getUnqual (C));
87
60
return ImageTy;
88
61
}
89
62
@@ -101,9 +74,9 @@ StructType *getBinDescTy(Module &M) {
101
74
LLVMContext &C = M.getContext ();
102
75
StructType *DescTy = StructType::getTypeByName (C, " __tgt_bin_desc" );
103
76
if (!DescTy)
104
- DescTy = StructType::create (" __tgt_bin_desc " , Type::getInt32Ty (C),
105
- getDeviceImagePtrTy (M ), getEntryPtrTy (M),
106
- getEntryPtrTy (M ));
77
+ DescTy = StructType::create (
78
+ " __tgt_bin_desc " , Type::getInt32Ty (C ), getDeviceImagePtrTy (M),
79
+ PointerType::getUnqual (C), PointerType::getUnqual (C ));
107
80
return DescTy;
108
81
}
109
82
@@ -151,28 +124,8 @@ PointerType *getBinDescPtrTy(Module &M) {
151
124
// / Global variable that represents BinDesc is returned.
152
125
GlobalVariable *createBinDesc (Module &M, ArrayRef<ArrayRef<char >> Bufs) {
153
126
LLVMContext &C = M.getContext ();
154
- // Create external begin/end symbols for the offload entries table.
155
- auto *EntriesB = new GlobalVariable (
156
- M, getEntryTy (M), /* isConstant*/ true , GlobalValue::ExternalLinkage,
157
- /* Initializer*/ nullptr , " __start_omp_offloading_entries" );
158
- EntriesB->setVisibility (GlobalValue::HiddenVisibility);
159
- auto *EntriesE = new GlobalVariable (
160
- M, getEntryTy (M), /* isConstant*/ true , GlobalValue::ExternalLinkage,
161
- /* Initializer*/ nullptr , " __stop_omp_offloading_entries" );
162
- EntriesE->setVisibility (GlobalValue::HiddenVisibility);
163
-
164
- // We assume that external begin/end symbols that we have created above will
165
- // be defined by the linker. But linker will do that only if linker inputs
166
- // have section with "omp_offloading_entries" name which is not guaranteed.
167
- // So, we just create dummy zero sized object in the offload entries section
168
- // to force linker to define those symbols.
169
- auto *DummyInit =
170
- ConstantAggregateZero::get (ArrayType::get (getEntryTy (M), 0u ));
171
- auto *DummyEntry = new GlobalVariable (
172
- M, DummyInit->getType (), true , GlobalVariable::ExternalLinkage, DummyInit,
173
- " __dummy.omp_offloading.entry" );
174
- DummyEntry->setSection (" omp_offloading_entries" );
175
- DummyEntry->setVisibility (GlobalValue::HiddenVisibility);
127
+ auto [EntriesB, EntriesE] =
128
+ offloading::getOffloadEntryArray (M, " omp_offloading_entries" );
176
129
177
130
auto *Zero = ConstantInt::get (getSizeTTy (M), 0u );
178
131
Constant *ZeroZero[] = {Zero, Zero};
@@ -328,18 +281,6 @@ GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
328
281
FatbinDesc->setSection (FatbinWrapperSection);
329
282
FatbinDesc->setAlignment (Align (8 ));
330
283
331
- // We create a dummy entry to ensure the linker will define the begin / end
332
- // symbols. The CUDA runtime should ignore the null address if we attempt to
333
- // register it.
334
- auto *DummyInit =
335
- ConstantAggregateZero::get (ArrayType::get (getEntryTy (M), 0u ));
336
- auto *DummyEntry = new GlobalVariable (
337
- M, DummyInit->getType (), true , GlobalVariable::ExternalLinkage, DummyInit,
338
- IsHIP ? " __dummy.hip_offloading.entry" : " __dummy.cuda_offloading.entry" );
339
- DummyEntry->setVisibility (GlobalValue::HiddenVisibility);
340
- DummyEntry->setSection (IsHIP ? " hip_offloading_entries"
341
- : " cuda_offloading_entries" );
342
-
343
284
return FatbinDesc;
344
285
}
345
286
@@ -368,6 +309,9 @@ GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
368
309
// / }
369
310
Function *createRegisterGlobalsFunction (Module &M, bool IsHIP) {
370
311
LLVMContext &C = M.getContext ();
312
+ auto [EntriesB, EntriesE] = offloading::getOffloadEntryArray (
313
+ M, IsHIP ? " hip_offloading_entries" : " cuda_offloading_entries" );
314
+
371
315
// Get the __cudaRegisterFunction function declaration.
372
316
PointerType *Int8PtrTy = PointerType::get (C, 0 );
373
317
PointerType *Int8PtrPtrTy = PointerType::get (C, 0 );
@@ -389,22 +333,6 @@ Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
389
333
FunctionCallee RegVar = M.getOrInsertFunction (
390
334
IsHIP ? " __hipRegisterVar" : " __cudaRegisterVar" , RegVarTy);
391
335
392
- // Create the references to the start / stop symbols defined by the linker.
393
- auto *EntriesB =
394
- new GlobalVariable (M, ArrayType::get (getEntryTy (M), 0 ),
395
- /* isConstant*/ true , GlobalValue::ExternalLinkage,
396
- /* Initializer*/ nullptr ,
397
- IsHIP ? " __start_hip_offloading_entries"
398
- : " __start_cuda_offloading_entries" );
399
- EntriesB->setVisibility (GlobalValue::HiddenVisibility);
400
- auto *EntriesE =
401
- new GlobalVariable (M, ArrayType::get (getEntryTy (M), 0 ),
402
- /* isConstant*/ true , GlobalValue::ExternalLinkage,
403
- /* Initializer*/ nullptr ,
404
- IsHIP ? " __stop_hip_offloading_entries"
405
- : " __stop_cuda_offloading_entries" );
406
- EntriesE->setVisibility (GlobalValue::HiddenVisibility);
407
-
408
336
auto *RegGlobalsTy = FunctionType::get (Type::getVoidTy (C), Int8PtrPtrTy,
409
337
/* isVarArg*/ false );
410
338
auto *RegGlobalsFn =
@@ -427,24 +355,24 @@ Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
427
355
auto *EntryCmp = Builder.CreateICmpNE (EntriesB, EntriesE);
428
356
Builder.CreateCondBr (EntryCmp, EntryBB, ExitBB);
429
357
Builder.SetInsertPoint (EntryBB);
430
- auto *Entry = Builder.CreatePHI (getEntryPtrTy (M ), 2 , " entry" );
358
+ auto *Entry = Builder.CreatePHI (PointerType::getUnqual (C ), 2 , " entry" );
431
359
auto *AddrPtr =
432
- Builder.CreateInBoundsGEP (getEntryTy (M), Entry,
360
+ Builder.CreateInBoundsGEP (offloading:: getEntryTy (M), Entry,
433
361
{ConstantInt::get (getSizeTTy (M), 0 ),
434
362
ConstantInt::get (Type::getInt32Ty (C), 0 )});
435
363
auto *Addr = Builder.CreateLoad (Int8PtrTy, AddrPtr, " addr" );
436
364
auto *NamePtr =
437
- Builder.CreateInBoundsGEP (getEntryTy (M), Entry,
365
+ Builder.CreateInBoundsGEP (offloading:: getEntryTy (M), Entry,
438
366
{ConstantInt::get (getSizeTTy (M), 0 ),
439
367
ConstantInt::get (Type::getInt32Ty (C), 1 )});
440
368
auto *Name = Builder.CreateLoad (Int8PtrTy, NamePtr, " name" );
441
369
auto *SizePtr =
442
- Builder.CreateInBoundsGEP (getEntryTy (M), Entry,
370
+ Builder.CreateInBoundsGEP (offloading:: getEntryTy (M), Entry,
443
371
{ConstantInt::get (getSizeTTy (M), 0 ),
444
372
ConstantInt::get (Type::getInt32Ty (C), 2 )});
445
373
auto *Size = Builder.CreateLoad (getSizeTTy (M), SizePtr, " size" );
446
374
auto *FlagsPtr =
447
- Builder.CreateInBoundsGEP (getEntryTy (M), Entry,
375
+ Builder.CreateInBoundsGEP (offloading:: getEntryTy (M), Entry,
448
376
{ConstantInt::get (getSizeTTy (M), 0 ),
449
377
ConstantInt::get (Type::getInt32Ty (C), 3 )});
450
378
auto *Flags = Builder.CreateLoad (Type::getInt32Ty (C), FlagsPtr, " flag" );
@@ -491,16 +419,16 @@ Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
491
419
492
420
Builder.SetInsertPoint (IfEndBB);
493
421
auto *NewEntry = Builder.CreateInBoundsGEP (
494
- getEntryTy (M), Entry, ConstantInt::get (getSizeTTy (M), 1 ));
422
+ offloading:: getEntryTy (M), Entry, ConstantInt::get (getSizeTTy (M), 1 ));
495
423
auto *Cmp = Builder.CreateICmpEQ (
496
424
NewEntry,
497
425
ConstantExpr::getInBoundsGetElementPtr (
498
- ArrayType::get (getEntryTy (M), 0 ), EntriesE,
426
+ ArrayType::get (offloading:: getEntryTy (M), 0 ), EntriesE,
499
427
ArrayRef<Constant *>({ConstantInt::get (getSizeTTy (M), 0 ),
500
428
ConstantInt::get (getSizeTTy (M), 0 )})));
501
429
Entry->addIncoming (
502
430
ConstantExpr::getInBoundsGetElementPtr (
503
- ArrayType::get (getEntryTy (M), 0 ), EntriesB,
431
+ ArrayType::get (offloading:: getEntryTy (M), 0 ), EntriesB,
504
432
ArrayRef<Constant *>({ConstantInt::get (getSizeTTy (M), 0 ),
505
433
ConstantInt::get (getSizeTTy (M), 0 )})),
506
434
&RegGlobalsFn->getEntryBlock ());
0 commit comments