[NVPTX] Basic support for "grid_constant" #96125

Merged
6 changes: 6 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1596,6 +1596,12 @@ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
"llvm.nvvm.ptr.gen.to.param">;

// sm70+, PTX7.7+
def int_nvvm_ptr_param_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty],
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
"llvm.nvvm.ptr.param.to.gen">;

// Move intrinsics, used in nvvm internally

def int_nvvm_move_i16 : Intrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem],
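For reference, a minimal sketch (not part of this patch) of how the overloaded declaration of the new intrinsic could be materialized from C++. The helper name and the hard-coded address spaces (0 = generic, 101 = param) are assumptions for illustration; with these overloads the mangled declaration should come out as something like `declare ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101))`.

```cpp
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Hypothetical helper: fetch (or insert) the declaration of the overloaded
// llvm.nvvm.ptr.param.to.gen intrinsic for a generic-pointer result and a
// param-space pointer argument.
static Function *getParamToGenDecl(Module &M) {
  LLVMContext &Ctx = M.getContext();
  Type *GenericPtrTy = PointerType::get(Ctx, /*AddressSpace=*/0);  // generic
  Type *ParamPtrTy = PointerType::get(Ctx, /*AddressSpace=*/101);  // param
  return Intrinsic::getDeclaration(&M, Intrinsic::nvvm_ptr_param_to_gen,
                                   {GenericPtrTy, ParamPtrTy});
}
```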
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2475,6 +2475,7 @@ defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>
defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;

defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;
77 changes: 56 additions & 21 deletions llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -95,7 +95,9 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/InitializePasses.h"
@@ -336,8 +338,9 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
while (!ValuesToCheck.empty()) {
Value *V = ValuesToCheck.pop_back_val();
if (!IsALoadChainInstr(V)) {
LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
<< "\n");
LLVM_DEBUG(dbgs() << "Need a "
<< (isParamGridConstant(*Arg) ? "cast " : "copy ")
<< "of " << *Arg << " because of " << *V << "\n");
(void)Arg;
return false;
}
Expand Down Expand Up @@ -366,27 +369,59 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
return;
}

// Otherwise we have to create a temporary copy.
const DataLayout &DL = Func->getParent()->getDataLayout();
unsigned AS = DL.getAllocaAddrSpace();
AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
// Set the alignment to alignment of the byval parameter. This is because,
// later load/stores assume that alignment, and we are going to replace
// the use of the byval parameter with this alloca instruction.
AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
.value_or(DL.getPrefTypeAlign(StructType)));
Arg->replaceAllUsesWith(AllocA);

Value *ArgInParam = new AddrSpaceCastInst(
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
FirstInst);
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
// addrspacecast preserves alignment. Since params are constant, this load is
// definitely not volatile.
LoadInst *LI =
new LoadInst(StructType, ArgInParam, Arg->getName(),
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
new StoreInst(LI, AllocA, FirstInst);
if (isParamGridConstant(*Arg)) {
// Writes to a grid constant are undefined behaviour. We do not need a
// temporary copy. When a pointer might have escaped, conservatively replace
// all of its uses (which might include a device function call) with a cast
// to the generic address space.
// TODO: only cast byval grid_constant parameters at use points that need a
// generic address (e.g., merging parameter pointers with pointers in other
// address spaces, escaping to call-sites, inline-asm, or memory), and use
// the parameter address space for normal loads.
IRBuilder<> IRB(&Func->getEntryBlock().front());

// Cast argument to param address space
auto *CastToParam =
cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));

// Cast the param address back to the generic address space. We do not use a
// plain addrspacecast to generic here because LLVM considers `Arg` to be in
// the generic address space already, so a `generic -> param` cast followed
// by a `param -> generic` cast would be folded away. The `param -> generic`
// intrinsic is correctly lowered to `cvta.param`.
Value *CvtToGenCall = IRB.CreateIntrinsic(
IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
CastToParam, nullptr, CastToParam->getName() + ".gen");

Arg->replaceAllUsesWith(CvtToGenCall);

// Do not replace Arg in the cast to param space
CastToParam->setOperand(0, Arg);
} else {
// Otherwise we have to create a temporary copy.
AllocaInst *AllocA =
new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
// Set the alignment to the alignment of the byval parameter. This is
// because later loads/stores assume that alignment, and we are going to
// replace the use of the byval parameter with this alloca instruction.
AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
.value_or(DL.getPrefTypeAlign(StructType)));
Arg->replaceAllUsesWith(AllocA);

Value *ArgInParam = new AddrSpaceCastInst(
Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
Arg->getName(), FirstInst);
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
// addrspacecast preserves alignment. Since params are constant, this load
// is definitely not volatile.
LoadInst *LI =
new LoadInst(StructType, ArgInParam, Arg->getName(),
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
new StoreInst(LI, AllocA, FirstInst);
}
}
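To make the cast-instead-of-copy path above concrete, here is a minimal standalone sketch of the same two-step pattern. It is illustrative only: `castGridConstantToGeneric` is not a helper in this patch, and the literals 101/0 stand in for ADDRESS_SPACE_PARAM/ADDRESS_SPACE_GENERIC.

```cpp
#include "llvm/IR/Argument.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

using namespace llvm;

// Sketch of the grid_constant lowering above: generic -> param via a plain
// addrspacecast, then param -> generic via the intrinsic, so the round-trip
// cannot be folded back into the original generic pointer.
static Value *castGridConstantToGeneric(IRBuilder<> &IRB, Argument *Arg) {
  constexpr unsigned ParamAS = 101;  // ADDRESS_SPACE_PARAM
  constexpr unsigned GenericAS = 0;  // ADDRESS_SPACE_GENERIC
  Value *InParam = IRB.CreateAddrSpaceCast(Arg, IRB.getPtrTy(ParamAS),
                                           Arg->getName() + ".param");
  return IRB.CreateIntrinsic(IRB.getPtrTy(GenericAS),
                             Intrinsic::nvvm_ptr_param_to_gen, {InParam},
                             /*FMFSource=*/nullptr,
                             InParam->getName() + ".gen");
}
```

A caller would then replace all uses of the argument with the returned value and, as the patch does, reset the addrspacecast's operand back to `Arg` so the cast itself is not rewritten.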

void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
143 changes: 78 additions & 65 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -52,29 +52,46 @@ void clearAnnotationCache(const Module *Mod) {
AC.Cache.erase(Mod);
}

static void cacheAnnotationFromMD(const MDNode *md, key_val_pair_t &retval) {
static void readIntVecFromMDNode(const MDNode *MetadataNode,
std::vector<unsigned> &Vec) {
for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
ConstantInt *Val =
mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
Vec.push_back(Val->getZExtValue());
}
}

static void cacheAnnotationFromMD(const MDNode *MetadataNode,
key_val_pair_t &retval) {
auto &AC = getAnnotationCache();
std::lock_guard<sys::Mutex> Guard(AC.Lock);
assert(md && "Invalid mdnode for annotation");
assert((md->getNumOperands() % 2) == 1 && "Invalid number of operands");
assert(MetadataNode && "Invalid mdnode for annotation");
assert((MetadataNode->getNumOperands() % 2) == 1 &&
"Invalid number of operands");
// start index = 1, to skip the global variable key
// increment = 2, to skip the value of each property-value pair
for (unsigned i = 1, e = md->getNumOperands(); i != e; i += 2) {
for (unsigned i = 1, e = MetadataNode->getNumOperands(); i != e; i += 2) {
// property
const MDString *prop = dyn_cast<MDString>(md->getOperand(i));
const MDString *prop = dyn_cast<MDString>(MetadataNode->getOperand(i));
assert(prop && "Annotation property not a string");
std::string Key = prop->getString().str();

// value
ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(md->getOperand(i + 1));
assert(Val && "Value operand not a constant int");

std::string keyname = prop->getString().str();
if (retval.find(keyname) != retval.end())
retval[keyname].push_back(Val->getZExtValue());
else {
std::vector<unsigned> tmp;
tmp.push_back(Val->getZExtValue());
retval[keyname] = tmp;
if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
MetadataNode->getOperand(i + 1))) {
retval[Key].push_back(Val->getZExtValue());
} else if (MDNode *VecMd =
dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
// Note: only "grid_constant" annotations support vector MDNodes.
// Invariant: at most one key-value pair of the form
// (string key, MDNode node) can exist, and the operands of such a
// node are always unsigned ints.
if (retval.find(Key) == retval.end()) {
readIntVecFromMDNode(VecMd, retval[Key]);
continue;
}
} else {
llvm_unreachable("Value operand not a constant int or an mdnode");
}
}
}
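For context, a sketch of the producer side of this metadata shape (hypothetical helper, not part of this patch): a single `!nvvm.annotations` entry of the form `!{ptr @kernel, !"grid_constant", !{i32 <1-based index>, ...}}`, which the loop above reads back through `readIntVecFromMDNode`.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Illustrative producer for the annotation shape parsed above: one
// !nvvm.annotations entry of the form
//   !{ptr @kernel, !"grid_constant", !{i32 <1-based arg index>, ...}}
static void markGridConstants(Module &M, Function &Kernel,
                              ArrayRef<unsigned> ZeroBasedArgIndices) {
  LLVMContext &Ctx = M.getContext();
  SmallVector<Metadata *, 4> Indices;
  for (unsigned Idx : ZeroBasedArgIndices)
    Indices.push_back(ConstantAsMetadata::get(
        ConstantInt::get(Type::getInt32Ty(Ctx), Idx + 1))); // 1-based in MD
  Metadata *Entry[] = {ConstantAsMetadata::get(&Kernel),
                       MDString::get(Ctx, "grid_constant"),
                       MDNode::get(Ctx, Indices)};
  M.getOrInsertNamedMetadata("nvvm.annotations")
      ->addOperand(MDNode::get(Ctx, Entry));
}
```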
@@ -153,9 +170,9 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,

bool isTexture(const Value &val) {
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
unsigned annot;
if (findOneNVVMAnnotation(gv, "texture", annot)) {
assert((annot == 1) && "Unexpected annotation on a texture symbol");
unsigned Annot;
if (findOneNVVMAnnotation(gv, "texture", Annot)) {
assert((Annot == 1) && "Unexpected annotation on a texture symbol");
return true;
}
}
@@ -164,70 +181,67 @@ bool isTexture(const Value &val) {

bool isSurface(const Value &val) {
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
unsigned annot;
if (findOneNVVMAnnotation(gv, "surface", annot)) {
assert((annot == 1) && "Unexpected annotation on a surface symbol");
unsigned Annot;
if (findOneNVVMAnnotation(gv, "surface", Annot)) {
assert((Annot == 1) && "Unexpected annotation on a surface symbol");
return true;
}
}
return false;
}

bool isSampler(const Value &val) {
const char *AnnotationName = "sampler";

if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
unsigned annot;
if (findOneNVVMAnnotation(gv, AnnotationName, annot)) {
assert((annot == 1) && "Unexpected annotation on a sampler symbol");
return true;
}
}
if (const Argument *arg = dyn_cast<Argument>(&val)) {
const Function *func = arg->getParent();
std::vector<unsigned> annot;
if (findAllNVVMAnnotation(func, AnnotationName, annot)) {
if (is_contained(annot, arg->getArgNo()))
static bool argHasNVVMAnnotation(const Value &Val,
const std::string &Annotation,
const bool StartArgIndexAtOne = false) {
if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
const Function *Func = Arg->getParent();
std::vector<unsigned> Annot;
if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
return true;
}
}
}
return false;
}

bool isImageReadOnly(const Value &val) {
if (const Argument *arg = dyn_cast<Argument>(&val)) {
const Function *func = arg->getParent();
std::vector<unsigned> annot;
if (findAllNVVMAnnotation(func, "rdoimage", annot)) {
if (is_contained(annot, arg->getArgNo()))
return true;
bool isParamGridConstant(const Value &V) {
if (const Argument *Arg = dyn_cast<Argument>(&V)) {
// "grid_constant" counts argument indices starting from 1
if (Arg->hasByValAttr() &&
argHasNVVMAnnotation(*Arg, "grid_constant", /*StartArgIndexAtOne=*/true)) {
assert(isKernelFunction(*Arg->getParent()) &&
"only kernel arguments can be grid_constant");
return true;
}
}
return false;
}
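A possible usage sketch of the new query (hypothetical, not part of this patch) that collects a kernel's grid_constant arguments; `isParamGridConstant` itself already enforces the byval and kernel-annotation requirements.

```cpp
#include "NVPTXUtilities.h" // NVPTX target-private header: isParamGridConstant
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Function.h"

using namespace llvm;

// Hypothetical helper: return the arguments of F that are grid_constant
// byval kernel parameters (ArgNo here is the usual 0-based index).
static SmallVector<const Argument *, 4>
collectGridConstantArgs(const Function &F) {
  SmallVector<const Argument *, 4> Result;
  for (const Argument &Arg : F.args())
    if (isParamGridConstant(Arg)) // implies byval + kernel + annotation
      Result.push_back(&Arg);
  return Result;
}
```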

bool isImageWriteOnly(const Value &val) {
if (const Argument *arg = dyn_cast<Argument>(&val)) {
const Function *func = arg->getParent();
std::vector<unsigned> annot;
if (findAllNVVMAnnotation(func, "wroimage", annot)) {
if (is_contained(annot, arg->getArgNo()))
return true;
bool isSampler(const Value &val) {
const char *AnnotationName = "sampler";

if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
unsigned Annot;
if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
return true;
}
}
return false;
return argHasNVVMAnnotation(val, AnnotationName);
}

bool isImageReadOnly(const Value &val) {
return argHasNVVMAnnotation(val, "rdoimage");
}

bool isImageWriteOnly(const Value &val) {
return argHasNVVMAnnotation(val, "wroimage");
}

bool isImageReadWrite(const Value &val) {
if (const Argument *arg = dyn_cast<Argument>(&val)) {
const Function *func = arg->getParent();
std::vector<unsigned> annot;
if (findAllNVVMAnnotation(func, "rdwrimage", annot)) {
if (is_contained(annot, arg->getArgNo()))
return true;
}
}
return false;
return argHasNVVMAnnotation(val, "rdwrimage");
}

bool isImage(const Value &val) {
@@ -236,9 +250,9 @@ bool isImage(const Value &val) {

bool isManaged(const Value &val) {
if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
unsigned annot;
if (findOneNVVMAnnotation(gv, "managed", annot)) {
assert((annot == 1) && "Unexpected annotation on a managed symbol");
unsigned Annot;
if (findOneNVVMAnnotation(gv, "managed", Annot)) {
assert((Annot == 1) && "Unexpected annotation on a managed symbol");
return true;
}
}
Expand Down Expand Up @@ -323,8 +337,7 @@ bool getMaxNReg(const Function &F, unsigned &x) {

bool isKernelFunction(const Function &F) {
unsigned x = 0;
bool retval = findOneNVVMAnnotation(&F, "kernel", x);
if (!retval) {
if (!findOneNVVMAnnotation(&F, "kernel", x)) {
// There is no NVVM metadata, check the calling convention
return F.getCallingConv() == CallingConv::PTX_Kernel;
}
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -62,6 +62,7 @@ bool getMaxClusterRank(const Function &, unsigned &);
bool getMinCTASm(const Function &, unsigned &);
bool getMaxNReg(const Function &, unsigned &);
bool isKernelFunction(const Function &);
bool isParamGridConstant(const Value &);

MaybeAlign getAlign(const Function &, unsigned);
MaybeAlign getAlign(const CallInst &, unsigned);