Skip to content

[SYCL][ESIMD] Fix crashes of sycl-post-link #7673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 62 additions & 56 deletions llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using namespace llvm;
namespace id = itanium_demangle;
using namespace llvm::esimd;

#undef DEBUG_TYPE
#define DEBUG_TYPE "lower-esimd"

#define SLM_BTI 254
Expand Down Expand Up @@ -1130,8 +1131,11 @@ static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName,
? GenXIntrinsic::getGenXDeclaration(
EEI->getModule(), ID, FixedVectorType::get(I32Ty, MAX_DIMS))
: GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID);

std::string ResultName = (Twine(EEI->getName()) + "." + FullIntrinName).str();
// Use hardcoded prefix when EEI has no name.
std::string ResultName =
((EEI->hasName() ? Twine(EEI->getName()) : Twine("Res")) + "." +
FullIntrinName)
.str();
Instruction *Inst = IntrinsicInst::Create(NewFDecl, {}, ResultName, EEI);
Inst->setDebugLoc(EEI->getDebugLoc());

Expand Down Expand Up @@ -1181,6 +1185,48 @@ bool translateLLVMIntrinsic(CallInst *CI) {
return true; // "intrinsic has been translated, erase the original call"
}

// Generate translation instructions for SPIRV global function calls
static Value *generateSpirvGlobalGenX(Instruction *EEI,
StringRef SpirvGlobalName,
uint64_t IndexValue) {
Value *NewInst = nullptr;
if (SpirvGlobalName == "WorkgroupSize") {
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue);
} else if (SpirvGlobalName == "LocalInvocationId") {
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue);
} else if (SpirvGlobalName == "WorkgroupId") {
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue);
} else if (SpirvGlobalName == "GlobalInvocationId") {
// GlobalId = LocalId + WorkGroupSize * GroupId
Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true, IndexValue);
Instruction *WGSizeI =
generateGenXCall(EEI, "local.size", true, IndexValue);
Instruction *GroupIdI =
generateGenXCall(EEI, "group.id", false, IndexValue);
Instruction *MulI =
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
} else if (SpirvGlobalName == "GlobalSize") {
// GlobalSize = WorkGroupSize * NumWorkGroups
Instruction *WGSizeI =
generateGenXCall(EEI, "local.size", true, IndexValue);
Instruction *NumWGI =
generateGenXCall(EEI, "group.count", true, IndexValue);
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
} else if (SpirvGlobalName == "GlobalOffset") {
// TODO: Support GlobalOffset SPIRV intrinsics
// Currently all users of load of GlobalOffset are replaced with 0.
NewInst = llvm::Constant::getNullValue(EEI->getType());
} else if (SpirvGlobalName == "NumWorkgroups") {
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue);
}

llvm::esimd::assert_and_diag(
NewInst, "Load from global SPIRV builtin was not translated");

return NewInst;
}

/// Replaces the load \p LI of SPIRV global with corresponding call(s) of GenX
/// intrinsic(s). The users of \p LI may also be transformed if needed for
/// def/use type correctness.
Expand All @@ -1207,6 +1253,12 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
llvm::APInt(32, 1, true));
} else if (SpirvGlobalName == "GlobalLinearId") {
NewInst = llvm::Constant::getNullValue(LI->getType());
} else if (isa<GetElementPtrConstantExpr>(LI->getPointerOperand())) {
// Translate the load that has getelementptr as an operand
auto *GEPCE = cast<GetElementPtrConstantExpr>(LI->getPointerOperand());
uint64_t IndexValue =
cast<Constant>(GEPCE->getOperand(2))->getUniqueInteger().getZExtValue();
NewInst = generateSpirvGlobalGenX(LI, SpirvGlobalName, IndexValue);
}
if (NewInst) {
LI->replaceAllUsesWith(NewInst);
Expand All @@ -1215,62 +1267,16 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
}

// Only loads from _vector_ SPIRV globals reach here now. Their users are
// expected to be ExtractElementInst or TruncInst only, and they are replaced
// in this loop. When loads from _scalar_ SPIRV globals are handled here as
// well, the users will not be replaced by new instructions, but the GenX call
// replacing the original load 'LI' should be inserted before each user.
// expected to be ExtractElementInst only, and they are
// replaced in this loop. When loads from _scalar_ SPIRV globals are handled
// here as well, the users will not be replaced by new instructions, but the
// GenX call replacing the original load 'LI' should be inserted before each
// user.
for (User *LU : LI->users()) {
assert(
(isa<ExtractElementInst>(LU) || isa<TruncInst>(LU)) &&
"SPIRV global users should be either ExtractElementInst or TruncInst");
Instruction *EEI = cast<Instruction>(LU);
NewInst = nullptr;

uint64_t IndexValue = 0;
if (isa<ExtractElementInst>(EEI)) {
IndexValue = getIndexFromExtract(cast<ExtractElementInst>(EEI));
} else {
auto *GEPCE = cast<GetElementPtrConstantExpr>(LI->getPointerOperand());

IndexValue = cast<Constant>(GEPCE->getOperand(2))
->getUniqueInteger()
.getZExtValue();
}

if (SpirvGlobalName == "WorkgroupSize") {
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue);
} else if (SpirvGlobalName == "LocalInvocationId") {
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue);
} else if (SpirvGlobalName == "WorkgroupId") {
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue);
} else if (SpirvGlobalName == "GlobalInvocationId") {
// GlobalId = LocalId + WorkGroupSize * GroupId
Instruction *LocalIdI =
generateGenXCall(EEI, "local.id", true, IndexValue);
Instruction *WGSizeI =
generateGenXCall(EEI, "local.size", true, IndexValue);
Instruction *GroupIdI =
generateGenXCall(EEI, "group.id", false, IndexValue);
Instruction *MulI =
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
} else if (SpirvGlobalName == "GlobalSize") {
// GlobalSize = WorkGroupSize * NumWorkGroups
Instruction *WGSizeI =
generateGenXCall(EEI, "local.size", true, IndexValue);
Instruction *NumWGI =
generateGenXCall(EEI, "group.count", true, IndexValue);
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
} else if (SpirvGlobalName == "GlobalOffset") {
// TODO: Support GlobalOffset SPIRV intrinsics
// Currently all users of load of GlobalOffset are replaced with 0.
NewInst = llvm::Constant::getNullValue(EEI->getType());
} else if (SpirvGlobalName == "NumWorkgroups") {
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue);
}
ExtractElementInst *EEI = cast<ExtractElementInst>(LU);
uint64_t IndexValue = getIndexFromExtract(cast<ExtractElementInst>(EEI));

llvm::esimd::assert_and_diag(
NewInst, "Load from global SPIRV builtin was not translated");
NewInst = generateSpirvGlobalGenX(EEI, SpirvGlobalName, IndexValue);
EEI->replaceAllUsesWith(NewInst);
InstsToErase.push_back(EEI);
}
Expand Down
Loading