// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via ld.param
- // instruction, directly or via a pointer. Pointers to kernel
- // arguments can't be converted to generic address space.
+ // instruction, directly or via a pointer.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// ...
// }
//
- // 2. Convert pointers in a byval kernel parameter to pointers in the global
- // address space. As #2, it allows NVPTX to emit more ld/st.global. E.g.,
+ // 2. Convert byval kernel parameters to pointers in the param address space
+ // (so that NVPTX emits ld/st.param). Convert pointers *within* a byval
+ // kernel parameter to pointers in the global address space. This allows
+ // NVPTX to emit ld/st.global.
//
//    struct S {
//      int *x;
//
// "b" points to the global address space. In the IR level,
//
- //    define void @foo({i32*, i32*}* byval %input) {
- //      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
- //      %b = load i32*, i32** %b_ptr
+ //    define void @foo(ptr byval %input) {
+ //      %b_ptr = getelementptr {ptr, ptr}, ptr %input, i64 0, i32 1
+ //      %b = load ptr, ptr %b_ptr
//      ; use %b
//    }
//
// becomes
//
//    define void @foo({i32*, i32*}* byval %input) {
- //      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
- //      %b = load i32*, i32** %b_ptr
- //      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
- //      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
+ //      %b_param = addrspacecast ptr %input to ptr addrspace(101)
+ //      %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
+ //      %b = load ptr, ptr addrspace(101) %b_ptr
+ //      %b_global = addrspacecast ptr %b to ptr addrspace(1)
//      ; use %b_generic
//    }
//
+ // Create a local copy of byval kernel parameters that are used in a way that
+ // *might* mutate the parameter, by storing the parameter in an alloca.
+ // Mutations to "grid_constant" parameters are undefined behaviour and don't
+ // require local copies.
+ //
+ //    define void @foo(ptr byval(%struct.s) align 4 %input) {
+ //      store i32 42, ptr %input
+ //      ret void
+ //    }
+ //
+ // becomes
+ //
+ //    define void @foo(ptr byval(%struct.s) align 4 %input) #1 {
+ //      %input1 = alloca %struct.s, align 4
+ //      %input2 = addrspacecast ptr %input to ptr addrspace(101)
+ //      %input3 = load %struct.s, ptr addrspace(101) %input2, align 4
+ //      store %struct.s %input3, ptr %input1, align 4
+ //      store i32 42, ptr %input1, align 4
+ //      ret void
+ //    }
+ //
+ // If %input were passed to a device function or written to memory, we
+ // conservatively assume that %input gets mutated, and create a local copy.
+ //
+ // Convert param pointers to grid_constant byval kernel parameters that are
+ // passed into calls (device functions, intrinsics, inline asm), or otherwise
+ // "escape" (into stores/ptrtoints) to the generic address space, using the
+ // `nvvm.ptr.param.to.gen` intrinsic, so that NVPTX emits cvta.param
+ // (available on sm_70+).
+ //
+ //    define void @foo(ptr byval(%struct.s) %input) {
+ //      ; %input is a grid_constant
+ //      %call = call i32 @escape(ptr %input)
+ //      ret void
+ //    }
+ //
+ // becomes
+ //
+ //    define void @foo(ptr byval(%struct.s) %input) {
+ //      %input1 = addrspacecast ptr %input to ptr addrspace(101)
+ //      ; the following intrinsic converts the pointer to generic. We don't use
+ //      ; an addrspacecast, to prevent generic -> param -> generic from getting
+ //      ; cancelled out
+ //      %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
+ //      %call = call i32 @escape(ptr %input1.gen)
+ //      ret void
+ //    }
+ //
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
// cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//
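As a source-level companion to the IR examples above, here is a hypothetical CUDA kernel (not part of this patch; the names are illustrative) matching the read-only case: the struct parameter is only loaded from, so its accesses can stay in the param space and the pointee accesses in the global space.

// Hypothetical CUDA source for the first transformation: the byval struct is
// only read, so no local copy is needed.
struct S {
  int *x;
  int *y;
};

__global__ void foo(S input) {
  // Loading input.y reads the param space (ld.param); dereferencing the
  // loaded pointer reads global memory (ld.global), because kernel pointer
  // arguments point to the global address space.
  int v = *input.y;
  // Writing through input.x is a store to global memory via a loaded
  // pointer; it does not write to the parameter itself.
  *input.x = v;
}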
@@ -166,19 +213,22 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
// ones in parameter AS, so we can access them using ld.param.
// =============================================================================

- // Replaces the \p OldUser instruction with the same in parameter AS.
- // Only Load and GEP are supported.
- static void convertToParamAS(Value *OldUser, Value *Param) {
-   Instruction *I = dyn_cast<Instruction>(OldUser);
-   assert(I && "OldUser must be an instruction");
+ // For Loads, replaces the \p OldUse of the pointer with a Use of the same
+ // pointer in parameter AS.
+ // For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
+ // generic using cvta.param.
+ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
+   Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
+   assert(I && "OldUse must be in an instruction");
  struct IP {
+     Use *OldUse;
    Instruction *OldInstruction;
    Value *NewParam;
  };
-   SmallVector<IP> ItemsToConvert = {{I, Param}};
+   SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
  SmallVector<Instruction *> InstructionsToDelete;

-   auto CloneInstInParamAS = [](const IP &I) -> Value * {
+   auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
      LI->setOperand(0, I.NewParam);
      return LI;
@@ -202,6 +252,43 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
      // Just pass through the argument, the old ASC is no longer needed.
      return I.NewParam;
    }
+
+     if (GridConstant) {
+       auto GetParamAddrCastToGeneric =
+           [](Value *Addr, Instruction *OriginalUser) -> Value * {
+         PointerType *ReturnTy =
+             PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
+         Function *CvtToGen = Intrinsic::getDeclaration(
+             OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
+             {ReturnTy, PointerType::get(OriginalUser->getContext(),
+                                         ADDRESS_SPACE_PARAM)});
+
+         // Cast param address to generic address space
+         Value *CvtToGenCall =
+             CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
+                              OriginalUser->getIterator());
+         return CvtToGenCall;
+       };
+
+       if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
+         I.OldUse->set(GetParamAddrCastToGeneric(I.NewParam, CI));
+         return CI;
+       }
+       if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
+         // byval address is being stored, cast it to generic
+         if (SI->getValueOperand() == I.OldUse->get())
+           SI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, SI));
+         return SI;
+       }
+       if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
+         if (PI->getPointerOperand() == I.OldUse->get())
+           PI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, PI));
+         return PI;
+       }
+       llvm_unreachable(
+           "Instruction unsupported even for grid_constant argument");
+     }
+
    llvm_unreachable("Unsupported instruction");
  };
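In CUDA terms, the GridConstant branch above corresponds to a __grid_constant__ kernel parameter whose address escapes, for example into a device-function call. A hedged sketch, assuming CUDA 11.7+ and sm_70+ (the callee and names are hypothetical, not from this patch):

struct S {
  int data[16];
};

__device__ int consume(const S *p) { return p->data[0]; }

// __grid_constant__ marks the parameter as never written, so no local copy
// is required; the escaping &input only needs a param -> generic conversion,
// which NVPTX emits as cvta.param.
__global__ void foo(const __grid_constant__ S input, int *out) {
  out[0] = consume(&input); // &input escapes into the call
}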
@@ -213,8 +300,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
    // We've created a new instruction. Queue users of the old instruction to
    // be converted and the instruction itself to be deleted. We can't delete
    // the old instruction yet, because it's still in use by a load somewhere.
-     for (Value *V : I.OldInstruction->users())
-       ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+     for (Use &U : I.OldInstruction->uses())
+       ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});

    InstructionsToDelete.push_back(I.OldInstruction);
  }
@@ -272,6 +359,7 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
  SmallVector<Load> Loads;
  std::queue<LoadContext> Worklist;
  Worklist.push({ArgInParamAS, 0});
+   bool IsGridConstant = isParamGridConstant(*Arg);

  while (!Worklist.empty()) {
    LoadContext Ctx = Worklist.front();
@@ -303,8 +391,14 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
        continue;
      }

+       // supported for grid_constant
+       if (IsGridConstant &&
+           (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
+            isa<PtrToIntInst>(CurUser)))
+         continue;
+
      llvm_unreachable("All users must be one of: load, "
-                        "bitcast, getelementptr.");
+                        "bitcast, getelementptr, call, store, ptrtoint");
    }
  }
@@ -317,49 +411,59 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                      Argument *Arg) {
+   bool IsGridConstant = isParamGridConstant(*Arg);
  Function *Func = Arg->getParent();
  BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
  Type *StructType = Arg->getParamByValType();
  assert(StructType && "Missing byval type");

-   auto IsALoadChain = [&](Value *Start) {
+   auto AreSupportedUsers = [&](Value *Start) {
    SmallVector<Value *, 16> ValuesToCheck = {Start};
-     auto IsALoadChainInstr = [](Value *V) -> bool {
+     auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
        return true;
      // ASC to param space are OK, too -- we'll just strip them.
      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
          return true;
      }
+       // Simple calls and stores are supported for grid_constants, since
+       // writes to these pointers are undefined behaviour.
+       if (IsGridConstant &&
+           (isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
+         return true;
      return false;
    };

    while (!ValuesToCheck.empty()) {
      Value *V = ValuesToCheck.pop_back_val();
-       if (!IsALoadChainInstr(V)) {
+       if (!IsSupportedUse(V)) {
        LLVM_DEBUG(dbgs() << "Need a "
                          << (isParamGridConstant(*Arg) ? "cast " : "copy ")
                          << "of " << *Arg << " because of " << *V << "\n");
        (void)Arg;
        return false;
      }
-       if (!isa<LoadInst>(V))
+       if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
+           !isa<PtrToIntInst>(V))
        llvm::append_range(ValuesToCheck, V->users());
    }
    return true;
  };

-   if (llvm::all_of(Arg->users(), IsALoadChain)) {
+   if (llvm::all_of(Arg->users(), AreSupportedUsers)) {
    // Convert all loads and intermediate operations to use parameter AS and
    // skip creation of a local copy of the argument.
-     SmallVector<User *, 16> UsersToUpdate(Arg->users());
+     SmallVector<Use *, 16> UsesToUpdate;
+     for (Use &U : Arg->uses())
+       UsesToUpdate.push_back(&U);
+
    Value *ArgInParamAS = new AddrSpaceCastInst(
        Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
        FirstInst);
-     for (Value *V : UsersToUpdate)
-       convertToParamAS(V, ArgInParamAS);
-     LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
+     for (Use *U : UsesToUpdate)
+       convertToParamAS(U, ArgInParamAS, IsGridConstant);
+     LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");

    const auto *TLI =
        cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
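For contrast, a hypothetical kernel that fails the AreSupportedUsers check above: it writes to its by-value parameter, so the pass falls through to the local-copy path (the alloca sequence shown in the file header) instead of rewriting uses to the param space. Whether such a store survives earlier optimizations depends on SROA, but this is the shape of source that produces one.

struct S {
  int a;
  int b;
};

__global__ void bar(S input, int *out) {
  // The increment stores through the parameter's address, which is not a
  // plain load chain, so a mutable local copy of the byval argument is made.
  input.a += threadIdx.x;
  out[threadIdx.x] = input.a + input.b;
}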
@@ -376,16 +480,11 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
  // temporary copy. When a pointer might have escaped, conservatively replace
  // all of its uses (which might include a device function call) with a cast
  // to the generic address space.
-   // TODO: only cast byval grid constant parameters at use points that need
-   // generic address (e.g., merging parameter pointers with other address
-   // space, or escaping to call-sites, inline-asm, memory), and use the
-   // parameter address space for normal loads.
  IRBuilder<> IRB(&Func->getEntryBlock().front());

  // Cast argument to param address space
-   auto *CastToParam =
-       cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
-           Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
+   auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
+       Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));

  // Cast param address to generic address space. We do not use an
  // addrspacecast to generic here because LLVM considers `Arg` to be in the