Skip to content

Commit 435addb

Browse files
authored
[NVPTX] Revamp NVVMIntrRange pass (#94422)
Revamp the NVVMIntrRange pass making the following updates: - Use range attributes over range metadata. This is what instcombine has move to for ranges on intrinsics in #88776 and it seems a bit cleaner. - Consider the `!"maxntid{x,y,z}"` and `!"reqntid{x,y,z}"` function metadata when adding ranges for `tid` srge instrinsics. This can allow for smaller ranges and more optimization. - When range attributes are already present, use the intersection of the old and new range. This complements the metadata change by allowing ranges to be shrunk when an intrinsic is in a function which is inlined into a kernel with metadata. While we don't call this more then once yet, we should consider adding a second call after inlining, once this has had a chance to soak for a while and no issues have arisen. I've also re-enabled this pass in the TM, it was disabled years ago due to "numerical discrepancies" https://reviews.llvm.org/D96166. In our testing we haven't seen any issues with adding ranges to intrinsics, and I cannot find any further info about what issues were encountered.
1 parent 31442c9 commit 435addb

File tree

9 files changed

+280
-198
lines changed

9 files changed

+280
-198
lines changed

clang/test/CodeGenCUDA/cuda-builtin-vars.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@
66
__attribute__((global))
77
void kernel(int *out) {
88
int i = 0;
9-
out[i++] = threadIdx.x; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x()
10-
out[i++] = threadIdx.y; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.y()
11-
out[i++] = threadIdx.z; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.z()
9+
out[i++] = threadIdx.x; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.x()
10+
out[i++] = threadIdx.y; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.y()
11+
out[i++] = threadIdx.z; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.z()
1212

13-
out[i++] = blockIdx.x; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
14-
out[i++] = blockIdx.y; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
15-
out[i++] = blockIdx.z; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
13+
out[i++] = blockIdx.x; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
14+
out[i++] = blockIdx.y; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
15+
out[i++] = blockIdx.z; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
1616

17-
out[i++] = blockDim.x; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
18-
out[i++] = blockDim.y; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
19-
out[i++] = blockDim.z; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
17+
out[i++] = blockDim.x; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
18+
out[i++] = blockDim.y; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
19+
out[i++] = blockDim.z; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
2020

21-
out[i++] = gridDim.x; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
22-
out[i++] = gridDim.y; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
23-
out[i++] = gridDim.z; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
21+
out[i++] = gridDim.x; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
22+
out[i++] = gridDim.y; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
23+
out[i++] = gridDim.z; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
2424

2525
out[i++] = warpSize; // CHECK: store i32 32,
2626

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ FunctionPass *createNVPTXISelDag(NVPTXTargetMachine &TM,
4040
ModulePass *createNVPTXAssignValidGlobalNamesPass();
4141
ModulePass *createGenericToNVVMLegacyPass();
4242
ModulePass *createNVPTXCtorDtorLoweringLegacyPass();
43-
FunctionPass *createNVVMIntrRangePass(unsigned int SmVersion);
43+
FunctionPass *createNVVMIntrRangePass();
4444
FunctionPass *createNVVMReflectPass(unsigned int SmVersion);
4545
MachineFunctionPass *createNVPTXPrologEpilogPass();
4646
MachineFunctionPass *createNVPTXReplaceImageHandlesPass();
@@ -53,12 +53,7 @@ MachineFunctionPass *createNVPTXPeephole();
5353
MachineFunctionPass *createNVPTXProxyRegErasurePass();
5454

5555
struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
56-
NVVMIntrRangePass();
57-
NVVMIntrRangePass(unsigned SmVersion) : SmVersion(SmVersion) {}
5856
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
59-
60-
private:
61-
unsigned SmVersion;
6257
};
6358

6459
struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -542,30 +542,24 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
542542
// If the NVVM IR has some of reqntid* specified, then output
543543
// the reqntid directive, and set the unspecified ones to 1.
544544
// If none of Reqntid* is specified, don't output reqntid directive.
545-
unsigned Reqntidx, Reqntidy, Reqntidz;
546-
Reqntidx = Reqntidy = Reqntidz = 1;
547-
bool ReqSpecified = false;
548-
ReqSpecified |= getReqNTIDx(F, Reqntidx);
549-
ReqSpecified |= getReqNTIDy(F, Reqntidy);
550-
ReqSpecified |= getReqNTIDz(F, Reqntidz);
545+
std::optional<unsigned> Reqntidx = getReqNTIDx(F);
546+
std::optional<unsigned> Reqntidy = getReqNTIDy(F);
547+
std::optional<unsigned> Reqntidz = getReqNTIDz(F);
551548

552-
if (ReqSpecified)
553-
O << ".reqntid " << Reqntidx << ", " << Reqntidy << ", " << Reqntidz
554-
<< "\n";
549+
if (Reqntidx || Reqntidy || Reqntidz)
550+
O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
551+
<< ", " << Reqntidz.value_or(1) << "\n";
555552

556553
// If the NVVM IR has some of maxntid* specified, then output
557554
// the maxntid directive, and set the unspecified ones to 1.
558555
// If none of maxntid* is specified, don't output maxntid directive.
559-
unsigned Maxntidx, Maxntidy, Maxntidz;
560-
Maxntidx = Maxntidy = Maxntidz = 1;
561-
bool MaxSpecified = false;
562-
MaxSpecified |= getMaxNTIDx(F, Maxntidx);
563-
MaxSpecified |= getMaxNTIDy(F, Maxntidy);
564-
MaxSpecified |= getMaxNTIDz(F, Maxntidz);
565-
566-
if (MaxSpecified)
567-
O << ".maxntid " << Maxntidx << ", " << Maxntidy << ", " << Maxntidz
568-
<< "\n";
556+
std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
557+
std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
558+
std::optional<unsigned> Maxntidz = getMaxNTIDz(F);
559+
560+
if (Maxntidx || Maxntidy || Maxntidz)
561+
O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
562+
<< ", " << Maxntidz.value_or(1) << "\n";
569563

570564
unsigned Mincta = 0;
571565
if (getMinCTASm(F, Mincta))

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,9 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(
233233
[this](ModulePassManager &PM, OptimizationLevel Level) {
234234
FunctionPassManager FPM;
235235
FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
236-
// FIXME: NVVMIntrRangePass is causing numerical discrepancies,
237-
// investigate and re-enable.
238-
// FPM.addPass(NVVMIntrRangePass(Subtarget.getSmVersion()));
236+
// Note: NVVMIntrRangePass was causing numerical discrepancies at one
237+
// point, if issues crop up, consider disabling.
238+
FPM.addPass(NVVMIntrRangePass());
239239
PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
240240
});
241241
}

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
128128
return true;
129129
}
130130

131+
static std::optional<unsigned>
132+
findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
133+
unsigned RetVal;
134+
if (findOneNVVMAnnotation(&GV, PropName, RetVal))
135+
return RetVal;
136+
return std::nullopt;
137+
}
138+
131139
bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
132140
std::vector<unsigned> &retval) {
133141
auto &AC = getAnnotationCache();
@@ -252,32 +260,57 @@ std::string getSamplerName(const Value &val) {
252260
return std::string(val.getName());
253261
}
254262

255-
bool getMaxNTIDx(const Function &F, unsigned &x) {
256-
return findOneNVVMAnnotation(&F, "maxntidx", x);
263+
std::optional<unsigned> getMaxNTIDx(const Function &F) {
264+
return findOneNVVMAnnotation(F, "maxntidx");
257265
}
258266

259-
bool getMaxNTIDy(const Function &F, unsigned &y) {
260-
return findOneNVVMAnnotation(&F, "maxntidy", y);
267+
std::optional<unsigned> getMaxNTIDy(const Function &F) {
268+
return findOneNVVMAnnotation(F, "maxntidy");
261269
}
262270

263-
bool getMaxNTIDz(const Function &F, unsigned &z) {
264-
return findOneNVVMAnnotation(&F, "maxntidz", z);
271+
std::optional<unsigned> getMaxNTIDz(const Function &F) {
272+
return findOneNVVMAnnotation(F, "maxntidz");
273+
}
274+
275+
std::optional<unsigned> getMaxNTID(const Function &F) {
276+
// Note: The semantics here are a bit strange. The PTX ISA states the
277+
// following (11.4.2. Performance-Tuning Directives: .maxntid):
278+
//
279+
// Note that this directive guarantees that the total number of threads does
280+
// not exceed the maximum, but does not guarantee that the limit in any
281+
// particular dimension is not exceeded.
282+
std::optional<unsigned> MaxNTIDx = getMaxNTIDx(F);
283+
std::optional<unsigned> MaxNTIDy = getMaxNTIDy(F);
284+
std::optional<unsigned> MaxNTIDz = getMaxNTIDz(F);
285+
if (MaxNTIDx || MaxNTIDy || MaxNTIDz)
286+
return MaxNTIDx.value_or(1) * MaxNTIDy.value_or(1) * MaxNTIDz.value_or(1);
287+
return std::nullopt;
265288
}
266289

267290
bool getMaxClusterRank(const Function &F, unsigned &x) {
268291
return findOneNVVMAnnotation(&F, "maxclusterrank", x);
269292
}
270293

271-
bool getReqNTIDx(const Function &F, unsigned &x) {
272-
return findOneNVVMAnnotation(&F, "reqntidx", x);
294+
std::optional<unsigned> getReqNTIDx(const Function &F) {
295+
return findOneNVVMAnnotation(F, "reqntidx");
296+
}
297+
298+
std::optional<unsigned> getReqNTIDy(const Function &F) {
299+
return findOneNVVMAnnotation(F, "reqntidy");
273300
}
274301

275-
bool getReqNTIDy(const Function &F, unsigned &y) {
276-
return findOneNVVMAnnotation(&F, "reqntidy", y);
302+
std::optional<unsigned> getReqNTIDz(const Function &F) {
303+
return findOneNVVMAnnotation(F, "reqntidz");
277304
}
278305

279-
bool getReqNTIDz(const Function &F, unsigned &z) {
280-
return findOneNVVMAnnotation(&F, "reqntidz", z);
306+
std::optional<unsigned> getReqNTID(const Function &F) {
307+
// Note: The semantics here are a bit strange. See getMaxNTID.
308+
std::optional<unsigned> ReqNTIDx = getReqNTIDx(F);
309+
std::optional<unsigned> ReqNTIDy = getReqNTIDy(F);
310+
std::optional<unsigned> ReqNTIDz = getReqNTIDz(F);
311+
if (ReqNTIDx || ReqNTIDy || ReqNTIDz)
312+
return ReqNTIDx.value_or(1) * ReqNTIDy.value_or(1) * ReqNTIDz.value_or(1);
313+
return std::nullopt;
281314
}
282315

283316
bool getMinCTASm(const Function &F, unsigned &x) {

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,15 @@ std::string getTextureName(const Value &);
4848
std::string getSurfaceName(const Value &);
4949
std::string getSamplerName(const Value &);
5050

51-
bool getMaxNTIDx(const Function &, unsigned &);
52-
bool getMaxNTIDy(const Function &, unsigned &);
53-
bool getMaxNTIDz(const Function &, unsigned &);
54-
55-
bool getReqNTIDx(const Function &, unsigned &);
56-
bool getReqNTIDy(const Function &, unsigned &);
57-
bool getReqNTIDz(const Function &, unsigned &);
51+
std::optional<unsigned> getMaxNTIDx(const Function &);
52+
std::optional<unsigned> getMaxNTIDy(const Function &);
53+
std::optional<unsigned> getMaxNTIDz(const Function &);
54+
std::optional<unsigned> getMaxNTID(const Function &F);
55+
56+
std::optional<unsigned> getReqNTIDx(const Function &);
57+
std::optional<unsigned> getReqNTIDy(const Function &);
58+
std::optional<unsigned> getReqNTIDz(const Function &);
59+
std::optional<unsigned> getReqNTID(const Function &);
5860

5961
bool getMaxClusterRank(const Function &, unsigned &);
6062
bool getMinCTASm(const Function &, unsigned &);

0 commit comments

Comments
 (0)