@@ -1059,9 +1059,10 @@ static void translateGetSurfaceIndex(CallInst &CI) {
1059
1059
// This helper function creates cast operation from GenX intrinsic return type
1060
1060
// to currently expected. Returns pointer to created cast instruction if it
1061
1061
// was created, otherwise returns NewI.
1062
- static Instruction *addCastInstIfNeeded (Instruction *OldI, Instruction *NewI) {
1062
+ static Instruction *addCastInstIfNeeded (Instruction *OldI, Instruction *NewI,
1063
+ Type *UseType = nullptr ) {
1063
1064
Type *NITy = NewI->getType ();
1064
- Type *OITy = OldI->getType ();
1065
+ Type *OITy = UseType ? UseType : OldI->getType ();
1065
1066
if (OITy != NITy) {
1066
1067
auto CastOpcode = CastInst::getCastOpcode (NewI, false , OITy, false );
1067
1068
NewI = CastInst::Create (CastOpcode, NewI, OITy,
@@ -1086,7 +1087,8 @@ static uint64_t getIndexFromExtract(ExtractElementInst *EEI) {
1086
1087
// / of vector load. The parameter \p IsVectorCall tells what version of GenX
1087
1088
// / intrinsic (scalar or vector) to use to lower the load from SPIRV global.
1088
1089
static Instruction *generateGenXCall (Instruction *EEI, StringRef IntrinName,
1089
- bool IsVectorCall, uint64_t IndexValue) {
1090
+ bool IsVectorCall, uint64_t IndexValue,
1091
+ Type *UseType) {
1090
1092
std::string Suffix =
1091
1093
IsVectorCall
1092
1094
? " .v3i32"
@@ -1125,7 +1127,7 @@ static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName,
1125
1127
ExtractName, EEI);
1126
1128
Inst->setDebugLoc (EEI->getDebugLoc ());
1127
1129
}
1128
- Inst = addCastInstIfNeeded (EEI, Inst);
1130
+ Inst = addCastInstIfNeeded (EEI, Inst, UseType );
1129
1131
return Inst;
1130
1132
}
1131
1133
@@ -1166,37 +1168,39 @@ bool translateLLVMIntrinsic(CallInst *CI) {
1166
1168
// Generate translation instructions for SPIRV global function calls
1167
1169
static Value *generateSpirvGlobalGenX (Instruction *EEI,
1168
1170
StringRef SpirvGlobalName,
1169
- uint64_t IndexValue) {
1171
+ uint64_t IndexValue,
1172
+ Type *UseType = nullptr ) {
1170
1173
Value *NewInst = nullptr ;
1171
1174
if (SpirvGlobalName == " WorkgroupSize" ) {
1172
- NewInst = generateGenXCall (EEI, " local.size" , true , IndexValue);
1175
+ NewInst = generateGenXCall (EEI, " local.size" , true , IndexValue, UseType );
1173
1176
} else if (SpirvGlobalName == " LocalInvocationId" ) {
1174
- NewInst = generateGenXCall (EEI, " local.id" , true , IndexValue);
1177
+ NewInst = generateGenXCall (EEI, " local.id" , true , IndexValue, UseType );
1175
1178
} else if (SpirvGlobalName == " WorkgroupId" ) {
1176
- NewInst = generateGenXCall (EEI, " group.id" , false , IndexValue);
1179
+ NewInst = generateGenXCall (EEI, " group.id" , false , IndexValue, UseType );
1177
1180
} else if (SpirvGlobalName == " GlobalInvocationId" ) {
1178
1181
// GlobalId = LocalId + WorkGroupSize * GroupId
1179
- Instruction *LocalIdI = generateGenXCall (EEI, " local.id" , true , IndexValue);
1182
+ Instruction *LocalIdI =
1183
+ generateGenXCall (EEI, " local.id" , true , IndexValue, UseType);
1180
1184
Instruction *WGSizeI =
1181
- generateGenXCall (EEI, " local.size" , true , IndexValue);
1185
+ generateGenXCall (EEI, " local.size" , true , IndexValue, UseType );
1182
1186
Instruction *GroupIdI =
1183
- generateGenXCall (EEI, " group.id" , false , IndexValue);
1187
+ generateGenXCall (EEI, " group.id" , false , IndexValue, UseType );
1184
1188
Instruction *MulI =
1185
1189
BinaryOperator::CreateMul (WGSizeI, GroupIdI, " mul" , EEI);
1186
1190
NewInst = BinaryOperator::CreateAdd (LocalIdI, MulI, " add" , EEI);
1187
1191
} else if (SpirvGlobalName == " GlobalSize" ) {
1188
1192
// GlobalSize = WorkGroupSize * NumWorkGroups
1189
1193
Instruction *WGSizeI =
1190
- generateGenXCall (EEI, " local.size" , true , IndexValue);
1194
+ generateGenXCall (EEI, " local.size" , true , IndexValue, UseType );
1191
1195
Instruction *NumWGI =
1192
- generateGenXCall (EEI, " group.count" , true , IndexValue);
1196
+ generateGenXCall (EEI, " group.count" , true , IndexValue, UseType );
1193
1197
NewInst = BinaryOperator::CreateMul (WGSizeI, NumWGI, " mul" , EEI);
1194
1198
} else if (SpirvGlobalName == " GlobalOffset" ) {
1195
1199
// TODO: Support GlobalOffset SPIRV intrinsics
1196
1200
// Currently all users of load of GlobalOffset are replaced with 0.
1197
1201
NewInst = llvm::Constant::getNullValue (EEI->getType ());
1198
1202
} else if (SpirvGlobalName == " NumWorkgroups" ) {
1199
- NewInst = generateGenXCall (EEI, " group.count" , true , IndexValue);
1203
+ NewInst = generateGenXCall (EEI, " group.count" , true , IndexValue, UseType );
1200
1204
}
1201
1205
1202
1206
llvm::esimd::assert_and_diag (
@@ -1255,8 +1259,8 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
1255
1259
SmallVector<User *> Users (LI->users ());
1256
1260
for (User *LU : Users) {
1257
1261
Instruction *Inst = cast<Instruction>(LU);
1258
- NewInst =
1259
- generateSpirvGlobalGenX (Inst, SpirvGlobalName, /* IndexValue= */ 0 );
1262
+ NewInst = generateSpirvGlobalGenX (Inst, SpirvGlobalName, /* IndexValue= */ 0 ,
1263
+ LI-> getType () );
1260
1264
LU->replaceUsesOfWith (LI, NewInst);
1261
1265
}
1262
1266
InstsToErase.push_back (LI);
0 commit comments