Skip to content

Commit 888d970

Browse files
committed
[NVPTX] Revamp NVVMIntrRange pass
1 parent 7db4e6c commit 888d970

File tree

7 files changed

+237
-155
lines changed

7 files changed

+237
-155
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -542,30 +542,24 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
542542
// If the NVVM IR has some of reqntid* specified, then output
543543
// the reqntid directive, and set the unspecified ones to 1.
544544
// If none of Reqntid* is specified, don't output reqntid directive.
545-
unsigned Reqntidx, Reqntidy, Reqntidz;
546-
Reqntidx = Reqntidy = Reqntidz = 1;
547-
bool ReqSpecified = false;
548-
ReqSpecified |= getReqNTIDx(F, Reqntidx);
549-
ReqSpecified |= getReqNTIDy(F, Reqntidy);
550-
ReqSpecified |= getReqNTIDz(F, Reqntidz);
545+
std::optional<unsigned> Reqntidx = getReqNTIDx(F);
546+
std::optional<unsigned> Reqntidy = getReqNTIDy(F);
547+
std::optional<unsigned> Reqntidz = getReqNTIDz(F);
551548

552-
if (ReqSpecified)
553-
O << ".reqntid " << Reqntidx << ", " << Reqntidy << ", " << Reqntidz
554-
<< "\n";
549+
if (Reqntidx || Reqntidy || Reqntidz)
550+
O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
551+
<< ", " << Reqntidz.value_or(1) << "\n";
555552

556553
// If the NVVM IR has some of maxntid* specified, then output
557554
// the maxntid directive, and set the unspecified ones to 1.
558555
// If none of maxntid* is specified, don't output maxntid directive.
559-
unsigned Maxntidx, Maxntidy, Maxntidz;
560-
Maxntidx = Maxntidy = Maxntidz = 1;
561-
bool MaxSpecified = false;
562-
MaxSpecified |= getMaxNTIDx(F, Maxntidx);
563-
MaxSpecified |= getMaxNTIDy(F, Maxntidy);
564-
MaxSpecified |= getMaxNTIDz(F, Maxntidz);
565-
566-
if (MaxSpecified)
567-
O << ".maxntid " << Maxntidx << ", " << Maxntidy << ", " << Maxntidz
568-
<< "\n";
556+
std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
557+
std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
558+
std::optional<unsigned> Maxntidz = getMaxNTIDz(F);
559+
560+
if (Maxntidx || Maxntidy || Maxntidz)
561+
O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
562+
<< ", " << Maxntidz.value_or(1) << "\n";
569563

570564
unsigned Mincta = 0;
571565
if (getMinCTASm(F, Mincta))

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,9 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(
233233
[this](ModulePassManager &PM, OptimizationLevel Level) {
234234
FunctionPassManager FPM;
235235
FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
236-
// FIXME: NVVMIntrRangePass is causing numerical discrepancies,
237-
// investigate and re-enable.
238-
// FPM.addPass(NVVMIntrRangePass(Subtarget.getSmVersion()));
236+
// Note: NVVMIntrRangePass was causing numerical discrepancies at one
237+
// point, if issues crop up, consider disabling.
238+
FPM.addPass(NVVMIntrRangePass(Subtarget.getSmVersion()));
239239
PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
240240
});
241241
}

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
128128
return true;
129129
}
130130

131+
static std::optional<unsigned>
132+
findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
133+
unsigned RetVal;
134+
bool Found = findOneNVVMAnnotation(&GV, PropName, RetVal);
135+
if (Found)
136+
return RetVal;
137+
return std::nullopt;
138+
}
139+
131140
bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
132141
std::vector<unsigned> &retval) {
133142
auto &AC = getAnnotationCache();
@@ -252,32 +261,57 @@ std::string getSamplerName(const Value &val) {
252261
return std::string(val.getName());
253262
}
254263

255-
bool getMaxNTIDx(const Function &F, unsigned &x) {
256-
return findOneNVVMAnnotation(&F, "maxntidx", x);
264+
std::optional<unsigned> getMaxNTIDx(const Function &F) {
265+
return findOneNVVMAnnotation(F, "maxntidx");
257266
}
258267

259-
bool getMaxNTIDy(const Function &F, unsigned &y) {
260-
return findOneNVVMAnnotation(&F, "maxntidy", y);
268+
std::optional<unsigned> getMaxNTIDy(const Function &F) {
269+
return findOneNVVMAnnotation(F, "maxntidy");
261270
}
262271

263-
bool getMaxNTIDz(const Function &F, unsigned &z) {
264-
return findOneNVVMAnnotation(&F, "maxntidz", z);
272+
std::optional<unsigned> getMaxNTIDz(const Function &F) {
273+
return findOneNVVMAnnotation(F, "maxntidz");
274+
}
275+
276+
std::optional<unsigned> getMaxNTID(const Function &F) {
277+
// Note: The semantics here are a bit strange. The PTX ISA states the
278+
// following (11.4.2. Performance-Tuning Directives: .maxntid):
279+
//
280+
// Note that this directive guarantees that the total number of threads does
281+
// not exceed the maximum, but does not guarantee that the limit in any
282+
// particular dimension is not exceeded.
283+
std::optional<unsigned> MaxNTIDx = getMaxNTIDx(F);
284+
std::optional<unsigned> MaxNTIDy = getMaxNTIDy(F);
285+
std::optional<unsigned> MaxNTIDz = getMaxNTIDz(F);
286+
if (MaxNTIDx || MaxNTIDy || MaxNTIDz)
287+
return MaxNTIDx.value_or(1) * MaxNTIDy.value_or(1) * MaxNTIDz.value_or(1);
288+
return std::nullopt;
265289
}
266290

267291
bool getMaxClusterRank(const Function &F, unsigned &x) {
268292
return findOneNVVMAnnotation(&F, "maxclusterrank", x);
269293
}
270294

271-
bool getReqNTIDx(const Function &F, unsigned &x) {
272-
return findOneNVVMAnnotation(&F, "reqntidx", x);
295+
std::optional<unsigned> getReqNTIDx(const Function &F) {
296+
return findOneNVVMAnnotation(F, "reqntidx");
297+
}
298+
299+
std::optional<unsigned> getReqNTIDy(const Function &F) {
300+
return findOneNVVMAnnotation(F, "reqntidy");
273301
}
274302

275-
bool getReqNTIDy(const Function &F, unsigned &y) {
276-
return findOneNVVMAnnotation(&F, "reqntidy", y);
303+
std::optional<unsigned> getReqNTIDz(const Function &F) {
304+
return findOneNVVMAnnotation(F, "reqntidz");
277305
}
278306

279-
bool getReqNTIDz(const Function &F, unsigned &z) {
280-
return findOneNVVMAnnotation(&F, "reqntidz", z);
307+
std::optional<unsigned> getReqNTID(const Function &F) {
308+
// Note: The semantics here are a bit strange. See getMaxNTID.
309+
std::optional<unsigned> ReqNTIDx = getReqNTIDx(F);
310+
std::optional<unsigned> ReqNTIDy = getReqNTIDy(F);
311+
std::optional<unsigned> ReqNTIDz = getReqNTIDz(F);
312+
if (ReqNTIDx || ReqNTIDy || ReqNTIDz)
313+
return ReqNTIDx.value_or(1) * ReqNTIDy.value_or(1) * ReqNTIDz.value_or(1);
314+
return std::nullopt;
281315
}
282316

283317
bool getMinCTASm(const Function &F, unsigned &x) {

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,15 @@ std::string getTextureName(const Value &);
4848
std::string getSurfaceName(const Value &);
4949
std::string getSamplerName(const Value &);
5050

51-
bool getMaxNTIDx(const Function &, unsigned &);
52-
bool getMaxNTIDy(const Function &, unsigned &);
53-
bool getMaxNTIDz(const Function &, unsigned &);
54-
55-
bool getReqNTIDx(const Function &, unsigned &);
56-
bool getReqNTIDy(const Function &, unsigned &);
57-
bool getReqNTIDz(const Function &, unsigned &);
51+
std::optional<unsigned> getMaxNTIDx(const Function &);
52+
std::optional<unsigned> getMaxNTIDy(const Function &);
53+
std::optional<unsigned> getMaxNTIDz(const Function &);
54+
std::optional<unsigned> getMaxNTID(const Function &F);
55+
56+
std::optional<unsigned> getReqNTIDx(const Function &);
57+
std::optional<unsigned> getReqNTIDy(const Function &);
58+
std::optional<unsigned> getReqNTIDz(const Function &);
59+
std::optional<unsigned> getReqNTID(const Function &);
5860

5961
bool getMaxClusterRank(const Function &, unsigned &);
6062
bool getMinCTASm(const Function &, unsigned &);

0 commit comments

Comments
 (0)