Skip to content

Commit 637d3d5

Browse files
authored
[SYCL][ESIMD] Fix GenXIntrinsic lowering for vectors used as scalars (#9211)
In most cases, we see the following IR for vector spirv intrinsics: ``` %0 = load <3 x i64>, <3 x i64> addrspace(4)* @__spirv_BuiltInGlobalInvocationId, align 32 %1 = extractelement <3 x i64> %0, i64 0 // some use of %1, ex: %2 = icmp ult i64 %0, 50 ``` However, if only the first element is used, this may be optimized to: ``` %0 = load i64, addrspace(4)* @__spirv_BuiltInGlobalInvocationId, align 32 // some use of %0, ex: %1 = icmp ult i64 %0, 50 ``` In the latter case, if we try to insert a cast due to different types between the real intrinsic type and the global variable type, make sure to use the load type, because the use type is arbitrary code and may not be the same as the global variable type. This is fine for the case with ExtractElement because in that case the type of the EE will be based on the global variable type. --------- Signed-off-by: Sarnie, Nick <[email protected]>
1 parent f81499e commit 637d3d5

File tree

2 files changed

+52
-16
lines changed

2 files changed

+52
-16
lines changed

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,9 +1059,10 @@ static void translateGetSurfaceIndex(CallInst &CI) {
10591059
// This helper function creates cast operation from GenX intrinsic return type
10601060
// to currently expected. Returns pointer to created cast instruction if it
10611061
// was created, otherwise returns NewI.
1062-
static Instruction *addCastInstIfNeeded(Instruction *OldI, Instruction *NewI) {
1062+
static Instruction *addCastInstIfNeeded(Instruction *OldI, Instruction *NewI,
1063+
Type *UseType = nullptr) {
10631064
Type *NITy = NewI->getType();
1064-
Type *OITy = OldI->getType();
1065+
Type *OITy = UseType ? UseType : OldI->getType();
10651066
if (OITy != NITy) {
10661067
auto CastOpcode = CastInst::getCastOpcode(NewI, false, OITy, false);
10671068
NewI = CastInst::Create(CastOpcode, NewI, OITy,
@@ -1086,7 +1087,8 @@ static uint64_t getIndexFromExtract(ExtractElementInst *EEI) {
10861087
/// of vector load. The parameter \p IsVectorCall tells what version of GenX
10871088
/// intrinsic (scalar or vector) to use to lower the load from SPIRV global.
10881089
static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName,
1089-
bool IsVectorCall, uint64_t IndexValue) {
1090+
bool IsVectorCall, uint64_t IndexValue,
1091+
Type *UseType) {
10901092
std::string Suffix =
10911093
IsVectorCall
10921094
? ".v3i32"
@@ -1125,7 +1127,7 @@ static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName,
11251127
ExtractName, EEI);
11261128
Inst->setDebugLoc(EEI->getDebugLoc());
11271129
}
1128-
Inst = addCastInstIfNeeded(EEI, Inst);
1130+
Inst = addCastInstIfNeeded(EEI, Inst, UseType);
11291131
return Inst;
11301132
}
11311133

@@ -1166,37 +1168,39 @@ bool translateLLVMIntrinsic(CallInst *CI) {
11661168
// Generate translation instructions for SPIRV global function calls
11671169
static Value *generateSpirvGlobalGenX(Instruction *EEI,
11681170
StringRef SpirvGlobalName,
1169-
uint64_t IndexValue) {
1171+
uint64_t IndexValue,
1172+
Type *UseType = nullptr) {
11701173
Value *NewInst = nullptr;
11711174
if (SpirvGlobalName == "WorkgroupSize") {
1172-
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue);
1175+
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue, UseType);
11731176
} else if (SpirvGlobalName == "LocalInvocationId") {
1174-
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue);
1177+
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue, UseType);
11751178
} else if (SpirvGlobalName == "WorkgroupId") {
1176-
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue);
1179+
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue, UseType);
11771180
} else if (SpirvGlobalName == "GlobalInvocationId") {
11781181
// GlobalId = LocalId + WorkGroupSize * GroupId
1179-
Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true, IndexValue);
1182+
Instruction *LocalIdI =
1183+
generateGenXCall(EEI, "local.id", true, IndexValue, UseType);
11801184
Instruction *WGSizeI =
1181-
generateGenXCall(EEI, "local.size", true, IndexValue);
1185+
generateGenXCall(EEI, "local.size", true, IndexValue, UseType);
11821186
Instruction *GroupIdI =
1183-
generateGenXCall(EEI, "group.id", false, IndexValue);
1187+
generateGenXCall(EEI, "group.id", false, IndexValue, UseType);
11841188
Instruction *MulI =
11851189
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
11861190
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
11871191
} else if (SpirvGlobalName == "GlobalSize") {
11881192
// GlobalSize = WorkGroupSize * NumWorkGroups
11891193
Instruction *WGSizeI =
1190-
generateGenXCall(EEI, "local.size", true, IndexValue);
1194+
generateGenXCall(EEI, "local.size", true, IndexValue, UseType);
11911195
Instruction *NumWGI =
1192-
generateGenXCall(EEI, "group.count", true, IndexValue);
1196+
generateGenXCall(EEI, "group.count", true, IndexValue, UseType);
11931197
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
11941198
} else if (SpirvGlobalName == "GlobalOffset") {
11951199
// TODO: Support GlobalOffset SPIRV intrinsics
11961200
// Currently all users of load of GlobalOffset are replaced with 0.
11971201
NewInst = llvm::Constant::getNullValue(EEI->getType());
11981202
} else if (SpirvGlobalName == "NumWorkgroups") {
1199-
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue);
1203+
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue, UseType);
12001204
}
12011205

12021206
llvm::esimd::assert_and_diag(
@@ -1255,8 +1259,8 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
12551259
SmallVector<User *> Users(LI->users());
12561260
for (User *LU : Users) {
12571261
Instruction *Inst = cast<Instruction>(LU);
1258-
NewInst =
1259-
generateSpirvGlobalGenX(Inst, SpirvGlobalName, /*IndexValue=*/0);
1262+
NewInst = generateSpirvGlobalGenX(Inst, SpirvGlobalName, /*IndexValue=*/0,
1263+
LI->getType());
12601264
LU->replaceUsesOfWith(LI, NewInst);
12611265
}
12621266
InstsToErase.push_back(LI);
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: opt -opaque-pointers < %s -passes=LowerESIMD -S | FileCheck %s
2+
3+
; This test checks we lower vector SPIRV globals correctly if
4+
; it is accessed as a scalar as an optimization to get the first element and needs a cast
5+
6+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
7+
target triple = "spir64-unknown-unknown"
8+
9+
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
10+
11+
define spir_kernel void @"__spirv_GlobalInvocationId_xyz"(i64 addrspace(1)* %_arg_) {
12+
; CHECK-LABEL: @__spirv_GlobalInvocationId_xyz(
13+
; CHECK-NEXT: entry:
14+
; CHECK-NEXT: [[DOTESIMD6:%.*]] = call <3 x i32> @llvm.genx.local.id.v3i32()
15+
; CHECK-NEXT: [[LOCAL_ID_X:%.*]] = extractelement <3 x i32> [[DOTESIMD6]], i32 0
16+
; CHECK-NEXT: [[LOCAL_ID_X_CAST_TY:%.*]] = zext i32 [[LOCAL_ID_X]] to i64
17+
; CHECK-NEXT: [[DOTESIMD7:%.*]] = call <3 x i32> @llvm.genx.local.size.v3i32()
18+
; CHECK-NEXT: [[WGSIZE_X:%.*]] = extractelement <3 x i32> [[DOTESIMD7]], i32 0
19+
; CHECK-NEXT: [[WGSIZE_X_CAST_TY:%.*]] = zext i32 [[WGSIZE_X]] to i64
20+
; CHECK-NEXT: [[GROUP_ID_X:%.*]] = call i32 @llvm.genx.group.id.x()
21+
; CHECK-NEXT: [[GROUP_ID_X_CAST_TY:%.*]] = zext i32 [[GROUP_ID_X]] to i64
22+
; CHECK-NEXT: [[MUL8:%.*]] = mul i64 [[WGSIZE_X_CAST_TY]], [[GROUP_ID_X_CAST_TY]]
23+
; CHECK-NEXT: [[ADD9:%.*]] = add i64 [[LOCAL_ID_X_CAST_TY]], [[MUL8]]
24+
; CHECK-NEXT: [[CAST10:%.*]] = icmp ult i64 [[ADD9]], 0
25+
26+
; Verify that the attribute is deleted from GenX declaration
27+
; CHECK-NOT: readnone
28+
entry:
29+
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
30+
%cmp.not.i = icmp ult i64 %0, 0
31+
ret void
32+
}

0 commit comments

Comments
 (0)