Skip to content

[SYCL][ESIMD] Fix GenXIntrinsic lowering for vectors used as scalars #9211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1059,9 +1059,10 @@ static void translateGetSurfaceIndex(CallInst &CI) {
// This helper function creates cast operation from GenX intrinsic return type
// to currently expected. Returns pointer to created cast instruction if it
// was created, otherwise returns NewI.
static Instruction *addCastInstIfNeeded(Instruction *OldI, Instruction *NewI) {
static Instruction *addCastInstIfNeeded(Instruction *OldI, Instruction *NewI,
Type *UseType = nullptr) {
Type *NITy = NewI->getType();
Type *OITy = OldI->getType();
Type *OITy = UseType ? UseType : OldI->getType();
if (OITy != NITy) {
auto CastOpcode = CastInst::getCastOpcode(NewI, false, OITy, false);
NewI = CastInst::Create(CastOpcode, NewI, OITy,
Expand All @@ -1086,7 +1087,8 @@ static uint64_t getIndexFromExtract(ExtractElementInst *EEI) {
/// of vector load. The parameter \p IsVectorCall tells what version of GenX
/// intrinsic (scalar or vector) to use to lower the load from SPIRV global.
static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName,
bool IsVectorCall, uint64_t IndexValue) {
bool IsVectorCall, uint64_t IndexValue,
Type *UseType) {
std::string Suffix =
IsVectorCall
? ".v3i32"
Expand Down Expand Up @@ -1125,7 +1127,7 @@ static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName,
ExtractName, EEI);
Inst->setDebugLoc(EEI->getDebugLoc());
}
Inst = addCastInstIfNeeded(EEI, Inst);
Inst = addCastInstIfNeeded(EEI, Inst, UseType);
return Inst;
}

Expand Down Expand Up @@ -1166,37 +1168,39 @@ bool translateLLVMIntrinsic(CallInst *CI) {
// Generate translation instructions for SPIRV global function calls
static Value *generateSpirvGlobalGenX(Instruction *EEI,
StringRef SpirvGlobalName,
uint64_t IndexValue) {
uint64_t IndexValue,
Type *UseType = nullptr) {
Value *NewInst = nullptr;
if (SpirvGlobalName == "WorkgroupSize") {
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue);
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue, UseType);
} else if (SpirvGlobalName == "LocalInvocationId") {
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue);
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue, UseType);
} else if (SpirvGlobalName == "WorkgroupId") {
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue);
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue, UseType);
} else if (SpirvGlobalName == "GlobalInvocationId") {
// GlobalId = LocalId + WorkGroupSize * GroupId
Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true, IndexValue);
Instruction *LocalIdI =
generateGenXCall(EEI, "local.id", true, IndexValue, UseType);
Instruction *WGSizeI =
generateGenXCall(EEI, "local.size", true, IndexValue);
generateGenXCall(EEI, "local.size", true, IndexValue, UseType);
Instruction *GroupIdI =
generateGenXCall(EEI, "group.id", false, IndexValue);
generateGenXCall(EEI, "group.id", false, IndexValue, UseType);
Instruction *MulI =
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
} else if (SpirvGlobalName == "GlobalSize") {
// GlobalSize = WorkGroupSize * NumWorkGroups
Instruction *WGSizeI =
generateGenXCall(EEI, "local.size", true, IndexValue);
generateGenXCall(EEI, "local.size", true, IndexValue, UseType);
Instruction *NumWGI =
generateGenXCall(EEI, "group.count", true, IndexValue);
generateGenXCall(EEI, "group.count", true, IndexValue, UseType);
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
} else if (SpirvGlobalName == "GlobalOffset") {
// TODO: Support GlobalOffset SPIRV intrinsics
// Currently all users of load of GlobalOffset are replaced with 0.
NewInst = llvm::Constant::getNullValue(EEI->getType());
} else if (SpirvGlobalName == "NumWorkgroups") {
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue);
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue, UseType);
}

llvm::esimd::assert_and_diag(
Expand Down Expand Up @@ -1255,8 +1259,8 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
SmallVector<User *> Users(LI->users());
for (User *LU : Users) {
Instruction *Inst = cast<Instruction>(LU);
NewInst =
generateSpirvGlobalGenX(Inst, SpirvGlobalName, /*IndexValue=*/0);
NewInst = generateSpirvGlobalGenX(Inst, SpirvGlobalName, /*IndexValue=*/0,
LI->getType());
LU->replaceUsesOfWith(LI, NewInst);
}
InstsToErase.push_back(LI);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
; RUN: opt -opaque-pointers < %s -passes=LowerESIMD -S | FileCheck %s

; This test checks that we lower vector SPIRV globals correctly when they are
; accessed as scalars (an optimization to read just the first element), which
; requires casting the GenX intrinsic result to the type of the scalar use.

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

define spir_kernel void @"__spirv_GlobalInvocationId_xyz"(i64 addrspace(1)* %_arg_) {
; CHECK-LABEL: @__spirv_GlobalInvocationId_xyz(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[DOTESIMD6:%.*]] = call <3 x i32> @llvm.genx.local.id.v3i32()
; CHECK-NEXT: [[LOCAL_ID_X:%.*]] = extractelement <3 x i32> [[DOTESIMD6]], i32 0
; CHECK-NEXT: [[LOCAL_ID_X_CAST_TY:%.*]] = zext i32 [[LOCAL_ID_X]] to i64
; CHECK-NEXT: [[DOTESIMD7:%.*]] = call <3 x i32> @llvm.genx.local.size.v3i32()
; CHECK-NEXT: [[WGSIZE_X:%.*]] = extractelement <3 x i32> [[DOTESIMD7]], i32 0
; CHECK-NEXT: [[WGSIZE_X_CAST_TY:%.*]] = zext i32 [[WGSIZE_X]] to i64
; CHECK-NEXT: [[GROUP_ID_X:%.*]] = call i32 @llvm.genx.group.id.x()
; CHECK-NEXT: [[GROUP_ID_X_CAST_TY:%.*]] = zext i32 [[GROUP_ID_X]] to i64
; CHECK-NEXT: [[MUL8:%.*]] = mul i64 [[WGSIZE_X_CAST_TY]], [[GROUP_ID_X_CAST_TY]]
; CHECK-NEXT: [[ADD9:%.*]] = add i64 [[LOCAL_ID_X_CAST_TY]], [[MUL8]]
; CHECK-NEXT: [[CAST10:%.*]] = icmp ult i64 [[ADD9]], 0

; Verify that the attribute is deleted from GenX declaration
; CHECK-NOT: readnone
entry:
  %0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
  %cmp.not.i = icmp ult i64 %0, 0
  ret void
}