Skip to content

[NVPTX][NFC] Refactor utilities to use std::optional #109883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,21 +563,19 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
<< ", " << Maxntidz.value_or(1) << "\n";

unsigned Mincta = 0;
if (getMinCTASm(F, Mincta))
O << ".minnctapersm " << Mincta << "\n";
if (const auto Mincta = getMinCTASm(F))
O << ".minnctapersm " << *Mincta << "\n";

unsigned Maxnreg = 0;
if (getMaxNReg(F, Maxnreg))
O << ".maxnreg " << Maxnreg << "\n";
if (const auto Maxnreg = getMaxNReg(F))
O << ".maxnreg " << *Maxnreg << "\n";

// .maxclusterrank directive requires SM_90 or higher, make sure that we
// filter it out for lower SM versions, as it causes a hard ptxas crash.
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
unsigned Maxclusterrank = 0;
if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
O << ".maxclusterrank " << Maxclusterrank << "\n";
if (STI->getSmVersion() >= 90)
if (const auto Maxclusterrank = getMaxClusterRank(F))
O << ".maxclusterrank " << *Maxclusterrank << "\n";
}

std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
Expand Down
139 changes: 54 additions & 85 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "NVPTXUtilities.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
Expand Down Expand Up @@ -130,8 +131,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
}
}

bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
unsigned &retval) {
static std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
const std::string &prop) {
auto &AC = getAnnotationCache();
std::lock_guard<sys::Mutex> Guard(AC.Lock);
const Module *m = gv->getParent();
Expand All @@ -140,21 +141,13 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
cacheAnnotationFromMD(m, gv);
if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
return false;
retval = AC.Cache[m][gv][prop][0];
return true;
}

static std::optional<unsigned>
findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
unsigned RetVal;
if (findOneNVVMAnnotation(&GV, PropName, RetVal))
return RetVal;
return std::nullopt;
return std::nullopt;
return AC.Cache[m][gv][prop][0];
}

bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
std::vector<unsigned> &retval) {
static bool findAllNVVMAnnotation(const GlobalValue *gv,
const std::string &prop,
std::vector<unsigned> &retval) {
auto &AC = getAnnotationCache();
std::lock_guard<sys::Mutex> Guard(AC.Lock);
const Module *m = gv->getParent();
Expand All @@ -168,25 +161,13 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
return true;
}

bool isTexture(const Value &val) {
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
unsigned Annot;
if (findOneNVVMAnnotation(gv, "texture", Annot)) {
assert((Annot == 1) && "Unexpected annotation on a texture symbol");
static bool globalHasNVVMAnnotation(const Value &V, const std::string &Prop) {
if (const auto *GV = dyn_cast<GlobalValue>(&V))
if (const auto Annot = findOneNVVMAnnotation(GV, Prop)) {
assert((*Annot == 1) && "Unexpected annotation on a symbol");
return true;
}
}
return false;
}

bool isSurface(const Value &val) {
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
unsigned Annot;
if (findOneNVVMAnnotation(gv, "surface", Annot)) {
assert((Annot == 1) && "Unexpected annotation on a surface symbol");
return true;
}
}
return false;
}

Expand Down Expand Up @@ -220,71 +201,60 @@ bool isParamGridConstant(const Value &V) {
return false;
}

bool isSampler(const Value &val) {
bool isTexture(const Value &V) { return globalHasNVVMAnnotation(V, "texture"); }

bool isSurface(const Value &V) { return globalHasNVVMAnnotation(V, "surface"); }

bool isSampler(const Value &V) {
const char *AnnotationName = "sampler";

if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
unsigned Annot;
if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
return true;
}
}
return argHasNVVMAnnotation(val, AnnotationName);
return globalHasNVVMAnnotation(V, AnnotationName) ||
argHasNVVMAnnotation(V, AnnotationName);
}

bool isImageReadOnly(const Value &val) {
return argHasNVVMAnnotation(val, "rdoimage");
bool isImageReadOnly(const Value &V) {
return argHasNVVMAnnotation(V, "rdoimage");
}

bool isImageWriteOnly(const Value &val) {
return argHasNVVMAnnotation(val, "wroimage");
bool isImageWriteOnly(const Value &V) {
return argHasNVVMAnnotation(V, "wroimage");
}

bool isImageReadWrite(const Value &val) {
return argHasNVVMAnnotation(val, "rdwrimage");
bool isImageReadWrite(const Value &V) {
return argHasNVVMAnnotation(V, "rdwrimage");
}

bool isImage(const Value &val) {
return isImageReadOnly(val) || isImageWriteOnly(val) || isImageReadWrite(val);
bool isImage(const Value &V) {
return isImageReadOnly(V) || isImageWriteOnly(V) || isImageReadWrite(V);
}

bool isManaged(const Value &val) {
if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
unsigned Annot;
if (findOneNVVMAnnotation(gv, "managed", Annot)) {
assert((Annot == 1) && "Unexpected annotation on a managed symbol");
return true;
}
}
return false;
}
bool isManaged(const Value &V) { return globalHasNVVMAnnotation(V, "managed"); }

std::string getTextureName(const Value &val) {
assert(val.hasName() && "Found texture variable with no name");
return std::string(val.getName());
StringRef getTextureName(const Value &V) {
assert(V.hasName() && "Found texture variable with no name");
return V.getName();
}

std::string getSurfaceName(const Value &val) {
assert(val.hasName() && "Found surface variable with no name");
return std::string(val.getName());
StringRef getSurfaceName(const Value &V) {
assert(V.hasName() && "Found surface variable with no name");
return V.getName();
}

std::string getSamplerName(const Value &val) {
assert(val.hasName() && "Found sampler variable with no name");
return std::string(val.getName());
StringRef getSamplerName(const Value &V) {
assert(V.hasName() && "Found sampler variable with no name");
return V.getName();
}

std::optional<unsigned> getMaxNTIDx(const Function &F) {
return findOneNVVMAnnotation(F, "maxntidx");
return findOneNVVMAnnotation(&F, "maxntidx");
}

std::optional<unsigned> getMaxNTIDy(const Function &F) {
return findOneNVVMAnnotation(F, "maxntidy");
return findOneNVVMAnnotation(&F, "maxntidy");
}

std::optional<unsigned> getMaxNTIDz(const Function &F) {
return findOneNVVMAnnotation(F, "maxntidz");
return findOneNVVMAnnotation(&F, "maxntidz");
}

std::optional<unsigned> getMaxNTID(const Function &F) {
Expand All @@ -302,20 +272,20 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
return std::nullopt;
}

bool getMaxClusterRank(const Function &F, unsigned &x) {
return findOneNVVMAnnotation(&F, "maxclusterrank", x);
std::optional<unsigned> getMaxClusterRank(const Function &F) {
return findOneNVVMAnnotation(&F, "maxclusterrank");
}

std::optional<unsigned> getReqNTIDx(const Function &F) {
return findOneNVVMAnnotation(F, "reqntidx");
return findOneNVVMAnnotation(&F, "reqntidx");
}

std::optional<unsigned> getReqNTIDy(const Function &F) {
return findOneNVVMAnnotation(F, "reqntidy");
return findOneNVVMAnnotation(&F, "reqntidy");
}

std::optional<unsigned> getReqNTIDz(const Function &F) {
return findOneNVVMAnnotation(F, "reqntidz");
return findOneNVVMAnnotation(&F, "reqntidz");
}

std::optional<unsigned> getReqNTID(const Function &F) {
Expand All @@ -328,21 +298,20 @@ std::optional<unsigned> getReqNTID(const Function &F) {
return std::nullopt;
}

bool getMinCTASm(const Function &F, unsigned &x) {
return findOneNVVMAnnotation(&F, "minctasm", x);
std::optional<unsigned> getMinCTASm(const Function &F) {
return findOneNVVMAnnotation(&F, "minctasm");
}

bool getMaxNReg(const Function &F, unsigned &x) {
return findOneNVVMAnnotation(&F, "maxnreg", x);
std::optional<unsigned> getMaxNReg(const Function &F) {
return findOneNVVMAnnotation(&F, "maxnreg");
}

bool isKernelFunction(const Function &F) {
unsigned x = 0;
if (!findOneNVVMAnnotation(&F, "kernel", x)) {
// There is no NVVM metadata, check the calling convention
return F.getCallingConv() == CallingConv::PTX_Kernel;
}
return (x == 1);
if (const auto X = findOneNVVMAnnotation(&F, "kernel"))
return (*X == 1);

// There is no NVVM metadata, check the calling convention
return F.getCallingConv() == CallingConv::PTX_Kernel;
}

MaybeAlign getAlign(const Function &F, unsigned Index) {
Expand Down
24 changes: 9 additions & 15 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ class TargetMachine;

void clearAnnotationCache(const Module *);

bool findOneNVVMAnnotation(const GlobalValue *, const std::string &,
unsigned &);
bool findAllNVVMAnnotation(const GlobalValue *, const std::string &,
std::vector<unsigned> &);

bool isTexture(const Value &);
bool isSurface(const Value &);
bool isSampler(const Value &);
Expand All @@ -45,23 +40,23 @@ bool isImageWriteOnly(const Value &);
bool isImageReadWrite(const Value &);
bool isManaged(const Value &);

std::string getTextureName(const Value &);
std::string getSurfaceName(const Value &);
std::string getSamplerName(const Value &);
StringRef getTextureName(const Value &);
StringRef getSurfaceName(const Value &);
StringRef getSamplerName(const Value &);

std::optional<unsigned> getMaxNTIDx(const Function &);
std::optional<unsigned> getMaxNTIDy(const Function &);
std::optional<unsigned> getMaxNTIDz(const Function &);
std::optional<unsigned> getMaxNTID(const Function &F);
std::optional<unsigned> getMaxNTID(const Function &);

std::optional<unsigned> getReqNTIDx(const Function &);
std::optional<unsigned> getReqNTIDy(const Function &);
std::optional<unsigned> getReqNTIDz(const Function &);
std::optional<unsigned> getReqNTID(const Function &);

bool getMaxClusterRank(const Function &, unsigned &);
bool getMinCTASm(const Function &, unsigned &);
bool getMaxNReg(const Function &, unsigned &);
std::optional<unsigned> getMaxClusterRank(const Function &);
std::optional<unsigned> getMinCTASm(const Function &);
std::optional<unsigned> getMaxNReg(const Function &);
bool isKernelFunction(const Function &);
bool isParamGridConstant(const Value &);

Expand All @@ -74,10 +69,9 @@ Function *getMaybeBitcastedCallee(const CallBase *CB);
inline unsigned promoteScalarArgumentSize(unsigned size) {
if (size <= 32)
return 32;
else if (size <= 64)
if (size <= 64)
return 64;
else
return size;
return size;
}

bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
Expand Down
Loading