Skip to content

Commit 5de8d26

Browse files
authored
[SYCL][ESIMD] Fix crashes of sycl-post-link (#7673)
This is to have sycl-post-link to handle IR in the form ``` %0 = load i64, i64 addrspace(1)* getelementptr (<3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, i64 0, i64 0), align 32 store i64 %0, i64 addrspace(1)* %_arg_DoNotOptimize, align 8 ```
1 parent 1dcd645 commit 5de8d26

File tree

2 files changed

+460
-63
lines changed

2 files changed

+460
-63
lines changed

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 62 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ using namespace llvm;
4444
namespace id = itanium_demangle;
4545
using namespace llvm::esimd;
4646

47+
#undef DEBUG_TYPE
4748
#define DEBUG_TYPE "lower-esimd"
4849

4950
#define SLM_BTI 254
@@ -1130,8 +1131,11 @@ static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName,
11301131
? GenXIntrinsic::getGenXDeclaration(
11311132
EEI->getModule(), ID, FixedVectorType::get(I32Ty, MAX_DIMS))
11321133
: GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID);
1133-
1134-
std::string ResultName = (Twine(EEI->getName()) + "." + FullIntrinName).str();
1134+
// Use hardcoded prefix when EEI has no name.
1135+
std::string ResultName =
1136+
((EEI->hasName() ? Twine(EEI->getName()) : Twine("Res")) + "." +
1137+
FullIntrinName)
1138+
.str();
11351139
Instruction *Inst = IntrinsicInst::Create(NewFDecl, {}, ResultName, EEI);
11361140
Inst->setDebugLoc(EEI->getDebugLoc());
11371141

@@ -1181,6 +1185,48 @@ bool translateLLVMIntrinsic(CallInst *CI) {
11811185
return true; // "intrinsic has been translated, erase the original call"
11821186
}
11831187

1188+
// Generate translation instructions for SPIRV global function calls
1189+
static Value *generateSpirvGlobalGenX(Instruction *EEI,
1190+
StringRef SpirvGlobalName,
1191+
uint64_t IndexValue) {
1192+
Value *NewInst = nullptr;
1193+
if (SpirvGlobalName == "WorkgroupSize") {
1194+
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue);
1195+
} else if (SpirvGlobalName == "LocalInvocationId") {
1196+
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue);
1197+
} else if (SpirvGlobalName == "WorkgroupId") {
1198+
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue);
1199+
} else if (SpirvGlobalName == "GlobalInvocationId") {
1200+
// GlobalId = LocalId + WorkGroupSize * GroupId
1201+
Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true, IndexValue);
1202+
Instruction *WGSizeI =
1203+
generateGenXCall(EEI, "local.size", true, IndexValue);
1204+
Instruction *GroupIdI =
1205+
generateGenXCall(EEI, "group.id", false, IndexValue);
1206+
Instruction *MulI =
1207+
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
1208+
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
1209+
} else if (SpirvGlobalName == "GlobalSize") {
1210+
// GlobalSize = WorkGroupSize * NumWorkGroups
1211+
Instruction *WGSizeI =
1212+
generateGenXCall(EEI, "local.size", true, IndexValue);
1213+
Instruction *NumWGI =
1214+
generateGenXCall(EEI, "group.count", true, IndexValue);
1215+
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
1216+
} else if (SpirvGlobalName == "GlobalOffset") {
1217+
// TODO: Support GlobalOffset SPIRV intrinsics
1218+
// Currently all users of load of GlobalOffset are replaced with 0.
1219+
NewInst = llvm::Constant::getNullValue(EEI->getType());
1220+
} else if (SpirvGlobalName == "NumWorkgroups") {
1221+
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue);
1222+
}
1223+
1224+
llvm::esimd::assert_and_diag(
1225+
NewInst, "Load from global SPIRV builtin was not translated");
1226+
1227+
return NewInst;
1228+
}
1229+
11841230
/// Replaces the load \p LI of SPIRV global with corresponding call(s) of GenX
11851231
/// intrinsic(s). The users of \p LI may also be transformed if needed for
11861232
/// def/use type correctness.
@@ -1207,6 +1253,12 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
12071253
llvm::APInt(32, 1, true));
12081254
} else if (SpirvGlobalName == "GlobalLinearId") {
12091255
NewInst = llvm::Constant::getNullValue(LI->getType());
1256+
} else if (isa<GetElementPtrConstantExpr>(LI->getPointerOperand())) {
1257+
// Translate the load that has getelementptr as an operand
1258+
auto *GEPCE = cast<GetElementPtrConstantExpr>(LI->getPointerOperand());
1259+
uint64_t IndexValue =
1260+
cast<Constant>(GEPCE->getOperand(2))->getUniqueInteger().getZExtValue();
1261+
NewInst = generateSpirvGlobalGenX(LI, SpirvGlobalName, IndexValue);
12101262
}
12111263
if (NewInst) {
12121264
LI->replaceAllUsesWith(NewInst);
@@ -1215,62 +1267,16 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
12151267
}
12161268

12171269
// Only loads from _vector_ SPIRV globals reach here now. Their users are
1218-
// expected to be ExtractElementInst or TruncInst only, and they are replaced
1219-
// in this loop. When loads from _scalar_ SPIRV globals are handled here as
1220-
// well, the users will not be replaced by new instructions, but the GenX call
1221-
// replacing the original load 'LI' should be inserted before each user.
1270+
// expected to be ExtractElementInst only, and they are
1271+
// replaced in this loop. When loads from _scalar_ SPIRV globals are handled
1272+
// here as well, the users will not be replaced by new instructions, but the
1273+
// GenX call replacing the original load 'LI' should be inserted before each
1274+
// user.
12221275
for (User *LU : LI->users()) {
1223-
assert(
1224-
(isa<ExtractElementInst>(LU) || isa<TruncInst>(LU)) &&
1225-
"SPIRV global users should be either ExtractElementInst or TruncInst");
1226-
Instruction *EEI = cast<Instruction>(LU);
1227-
NewInst = nullptr;
1228-
1229-
uint64_t IndexValue = 0;
1230-
if (isa<ExtractElementInst>(EEI)) {
1231-
IndexValue = getIndexFromExtract(cast<ExtractElementInst>(EEI));
1232-
} else {
1233-
auto *GEPCE = cast<GetElementPtrConstantExpr>(LI->getPointerOperand());
1234-
1235-
IndexValue = cast<Constant>(GEPCE->getOperand(2))
1236-
->getUniqueInteger()
1237-
.getZExtValue();
1238-
}
1239-
1240-
if (SpirvGlobalName == "WorkgroupSize") {
1241-
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue);
1242-
} else if (SpirvGlobalName == "LocalInvocationId") {
1243-
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue);
1244-
} else if (SpirvGlobalName == "WorkgroupId") {
1245-
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue);
1246-
} else if (SpirvGlobalName == "GlobalInvocationId") {
1247-
// GlobalId = LocalId + WorkGroupSize * GroupId
1248-
Instruction *LocalIdI =
1249-
generateGenXCall(EEI, "local.id", true, IndexValue);
1250-
Instruction *WGSizeI =
1251-
generateGenXCall(EEI, "local.size", true, IndexValue);
1252-
Instruction *GroupIdI =
1253-
generateGenXCall(EEI, "group.id", false, IndexValue);
1254-
Instruction *MulI =
1255-
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
1256-
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
1257-
} else if (SpirvGlobalName == "GlobalSize") {
1258-
// GlobalSize = WorkGroupSize * NumWorkGroups
1259-
Instruction *WGSizeI =
1260-
generateGenXCall(EEI, "local.size", true, IndexValue);
1261-
Instruction *NumWGI =
1262-
generateGenXCall(EEI, "group.count", true, IndexValue);
1263-
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
1264-
} else if (SpirvGlobalName == "GlobalOffset") {
1265-
// TODO: Support GlobalOffset SPIRV intrinsics
1266-
// Currently all users of load of GlobalOffset are replaced with 0.
1267-
NewInst = llvm::Constant::getNullValue(EEI->getType());
1268-
} else if (SpirvGlobalName == "NumWorkgroups") {
1269-
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue);
1270-
}
1276+
ExtractElementInst *EEI = cast<ExtractElementInst>(LU);
1277+
uint64_t IndexValue = getIndexFromExtract(cast<ExtractElementInst>(EEI));
12711278

1272-
llvm::esimd::assert_and_diag(
1273-
NewInst, "Load from global SPIRV builtin was not translated");
1279+
NewInst = generateSpirvGlobalGenX(EEI, SpirvGlobalName, IndexValue);
12741280
EEI->replaceAllUsesWith(NewInst);
12751281
InstsToErase.push_back(EEI);
12761282
}

0 commit comments

Comments
 (0)